Repository: snap-research/stable-flow Branch: main Commit: c99ac0eba297 Files: 478 Total size: 10.6 MB Directory structure: gitextract_693itit8/ ├── .gitignore ├── LICENSE ├── README.md ├── environment.yml ├── run_stable_flow.py ├── setup.py └── src/ └── diffusers/ ├── __init__.py ├── callbacks.py ├── commands/ │ ├── __init__.py │ ├── diffusers_cli.py │ ├── env.py │ └── fp16_safetensors.py ├── configuration_utils.py ├── dependency_versions_check.py ├── dependency_versions_table.py ├── experimental/ │ ├── README.md │ ├── __init__.py │ └── rl/ │ ├── __init__.py │ └── value_guided_sampling.py ├── image_processor.py ├── loaders/ │ ├── __init__.py │ ├── ip_adapter.py │ ├── lora_base.py │ ├── lora_conversion_utils.py │ ├── lora_pipeline.py │ ├── peft.py │ ├── single_file.py │ ├── single_file_model.py │ ├── single_file_utils.py │ ├── textual_inversion.py │ ├── unet.py │ ├── unet_loader_utils.py │ └── utils.py ├── models/ │ ├── README.md │ ├── __init__.py │ ├── activations.py │ ├── adapter.py │ ├── attention.py │ ├── attention_flax.py │ ├── attention_processor.py │ ├── autoencoders/ │ │ ├── __init__.py │ │ ├── autoencoder_asym_kl.py │ │ ├── autoencoder_kl.py │ │ ├── autoencoder_kl_cogvideox.py │ │ ├── autoencoder_kl_temporal_decoder.py │ │ ├── autoencoder_oobleck.py │ │ ├── autoencoder_tiny.py │ │ ├── consistency_decoder_vae.py │ │ ├── vae.py │ │ └── vq_model.py │ ├── controlnet.py │ ├── controlnet_flax.py │ ├── controlnet_hunyuan.py │ ├── controlnet_sd3.py │ ├── controlnet_sparsectrl.py │ ├── controlnet_xs.py │ ├── downsampling.py │ ├── embeddings.py │ ├── embeddings_flax.py │ ├── lora.py │ ├── model_loading_utils.py │ ├── modeling_flax_pytorch_utils.py │ ├── modeling_flax_utils.py │ ├── modeling_outputs.py │ ├── modeling_pytorch_flax_utils.py │ ├── modeling_utils.py │ ├── normalization.py │ ├── resnet.py │ ├── resnet_flax.py │ ├── transformers/ │ │ ├── __init__.py │ │ ├── auraflow_transformer_2d.py │ │ ├── cogvideox_transformer_3d.py │ │ ├── dit_transformer_2d.py │ │ ├── 
dual_transformer_2d.py │ │ ├── hunyuan_transformer_2d.py │ │ ├── latte_transformer_3d.py │ │ ├── lumina_nextdit2d.py │ │ ├── pixart_transformer_2d.py │ │ ├── prior_transformer.py │ │ ├── stable_audio_transformer.py │ │ ├── t5_film_transformer.py │ │ ├── transformer_2d.py │ │ ├── transformer_flux.py │ │ ├── transformer_sd3.py │ │ └── transformer_temporal.py │ ├── unets/ │ │ ├── __init__.py │ │ ├── unet_1d.py │ │ ├── unet_1d_blocks.py │ │ ├── unet_2d.py │ │ ├── unet_2d_blocks.py │ │ ├── unet_2d_blocks_flax.py │ │ ├── unet_2d_condition.py │ │ ├── unet_2d_condition_flax.py │ │ ├── unet_3d_blocks.py │ │ ├── unet_3d_condition.py │ │ ├── unet_i2vgen_xl.py │ │ ├── unet_kandinsky3.py │ │ ├── unet_motion_model.py │ │ ├── unet_spatio_temporal_condition.py │ │ ├── unet_stable_cascade.py │ │ └── uvit_2d.py │ ├── upsampling.py │ ├── vae_flax.py │ └── vq_model.py ├── optimization.py ├── pipelines/ │ ├── README.md │ ├── __init__.py │ ├── amused/ │ │ ├── __init__.py │ │ ├── pipeline_amused.py │ │ ├── pipeline_amused_img2img.py │ │ └── pipeline_amused_inpaint.py │ ├── animatediff/ │ │ ├── __init__.py │ │ ├── pipeline_animatediff.py │ │ ├── pipeline_animatediff_controlnet.py │ │ ├── pipeline_animatediff_sdxl.py │ │ ├── pipeline_animatediff_sparsectrl.py │ │ ├── pipeline_animatediff_video2video.py │ │ └── pipeline_output.py │ ├── audioldm/ │ │ ├── __init__.py │ │ └── pipeline_audioldm.py │ ├── audioldm2/ │ │ ├── __init__.py │ │ ├── modeling_audioldm2.py │ │ └── pipeline_audioldm2.py │ ├── aura_flow/ │ │ ├── __init__.py │ │ └── pipeline_aura_flow.py │ ├── auto_pipeline.py │ ├── blip_diffusion/ │ │ ├── __init__.py │ │ ├── blip_image_processing.py │ │ ├── modeling_blip2.py │ │ ├── modeling_ctx_clip.py │ │ └── pipeline_blip_diffusion.py │ ├── cogvideo/ │ │ ├── __init__.py │ │ └── pipeline_cogvideox.py │ ├── consistency_models/ │ │ ├── __init__.py │ │ └── pipeline_consistency_models.py │ ├── controlnet/ │ │ ├── __init__.py │ │ ├── multicontrolnet.py │ │ ├── pipeline_controlnet.py │ │ ├── 
pipeline_controlnet_blip_diffusion.py │ │ ├── pipeline_controlnet_img2img.py │ │ ├── pipeline_controlnet_inpaint.py │ │ ├── pipeline_controlnet_inpaint_sd_xl.py │ │ ├── pipeline_controlnet_sd_xl.py │ │ ├── pipeline_controlnet_sd_xl_img2img.py │ │ └── pipeline_flax_controlnet.py │ ├── controlnet_hunyuandit/ │ │ ├── __init__.py │ │ └── pipeline_hunyuandit_controlnet.py │ ├── controlnet_sd3/ │ │ ├── __init__.py │ │ └── pipeline_stable_diffusion_3_controlnet.py │ ├── controlnet_xs/ │ │ ├── __init__.py │ │ ├── pipeline_controlnet_xs.py │ │ └── pipeline_controlnet_xs_sd_xl.py │ ├── dance_diffusion/ │ │ ├── __init__.py │ │ └── pipeline_dance_diffusion.py │ ├── ddim/ │ │ ├── __init__.py │ │ └── pipeline_ddim.py │ ├── ddpm/ │ │ ├── __init__.py │ │ └── pipeline_ddpm.py │ ├── deepfloyd_if/ │ │ ├── __init__.py │ │ ├── pipeline_if.py │ │ ├── pipeline_if_img2img.py │ │ ├── pipeline_if_img2img_superresolution.py │ │ ├── pipeline_if_inpainting.py │ │ ├── pipeline_if_inpainting_superresolution.py │ │ ├── pipeline_if_superresolution.py │ │ ├── pipeline_output.py │ │ ├── safety_checker.py │ │ ├── timesteps.py │ │ └── watermark.py │ ├── deprecated/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── alt_diffusion/ │ │ │ ├── __init__.py │ │ │ ├── modeling_roberta_series.py │ │ │ ├── pipeline_alt_diffusion.py │ │ │ ├── pipeline_alt_diffusion_img2img.py │ │ │ └── pipeline_output.py │ │ ├── audio_diffusion/ │ │ │ ├── __init__.py │ │ │ ├── mel.py │ │ │ └── pipeline_audio_diffusion.py │ │ ├── latent_diffusion_uncond/ │ │ │ ├── __init__.py │ │ │ └── pipeline_latent_diffusion_uncond.py │ │ ├── pndm/ │ │ │ ├── __init__.py │ │ │ └── pipeline_pndm.py │ │ ├── repaint/ │ │ │ ├── __init__.py │ │ │ └── pipeline_repaint.py │ │ ├── score_sde_ve/ │ │ │ ├── __init__.py │ │ │ └── pipeline_score_sde_ve.py │ │ ├── spectrogram_diffusion/ │ │ │ ├── __init__.py │ │ │ ├── continuous_encoder.py │ │ │ ├── midi_utils.py │ │ │ ├── notes_encoder.py │ │ │ └── pipeline_spectrogram_diffusion.py │ │ ├── 
stable_diffusion_variants/ │ │ │ ├── __init__.py │ │ │ ├── pipeline_cycle_diffusion.py │ │ │ ├── pipeline_onnx_stable_diffusion_inpaint_legacy.py │ │ │ ├── pipeline_stable_diffusion_inpaint_legacy.py │ │ │ ├── pipeline_stable_diffusion_model_editing.py │ │ │ ├── pipeline_stable_diffusion_paradigms.py │ │ │ └── pipeline_stable_diffusion_pix2pix_zero.py │ │ ├── stochastic_karras_ve/ │ │ │ ├── __init__.py │ │ │ └── pipeline_stochastic_karras_ve.py │ │ ├── versatile_diffusion/ │ │ │ ├── __init__.py │ │ │ ├── modeling_text_unet.py │ │ │ ├── pipeline_versatile_diffusion.py │ │ │ ├── pipeline_versatile_diffusion_dual_guided.py │ │ │ ├── pipeline_versatile_diffusion_image_variation.py │ │ │ └── pipeline_versatile_diffusion_text_to_image.py │ │ └── vq_diffusion/ │ │ ├── __init__.py │ │ └── pipeline_vq_diffusion.py │ ├── dit/ │ │ ├── __init__.py │ │ └── pipeline_dit.py │ ├── flux/ │ │ ├── __init__.py │ │ ├── pipeline_flux.py │ │ └── pipeline_output.py │ ├── free_init_utils.py │ ├── free_noise_utils.py │ ├── hunyuandit/ │ │ ├── __init__.py │ │ └── pipeline_hunyuandit.py │ ├── i2vgen_xl/ │ │ ├── __init__.py │ │ └── pipeline_i2vgen_xl.py │ ├── kandinsky/ │ │ ├── __init__.py │ │ ├── pipeline_kandinsky.py │ │ ├── pipeline_kandinsky_combined.py │ │ ├── pipeline_kandinsky_img2img.py │ │ ├── pipeline_kandinsky_inpaint.py │ │ ├── pipeline_kandinsky_prior.py │ │ └── text_encoder.py │ ├── kandinsky2_2/ │ │ ├── __init__.py │ │ ├── pipeline_kandinsky2_2.py │ │ ├── pipeline_kandinsky2_2_combined.py │ │ ├── pipeline_kandinsky2_2_controlnet.py │ │ ├── pipeline_kandinsky2_2_controlnet_img2img.py │ │ ├── pipeline_kandinsky2_2_img2img.py │ │ ├── pipeline_kandinsky2_2_inpainting.py │ │ ├── pipeline_kandinsky2_2_prior.py │ │ └── pipeline_kandinsky2_2_prior_emb2emb.py │ ├── kandinsky3/ │ │ ├── __init__.py │ │ ├── convert_kandinsky3_unet.py │ │ ├── pipeline_kandinsky3.py │ │ └── pipeline_kandinsky3_img2img.py │ ├── kolors/ │ │ ├── __init__.py │ │ ├── pipeline_kolors.py │ │ ├── 
pipeline_kolors_img2img.py │ │ ├── pipeline_output.py │ │ ├── text_encoder.py │ │ └── tokenizer.py │ ├── latent_consistency_models/ │ │ ├── __init__.py │ │ ├── pipeline_latent_consistency_img2img.py │ │ └── pipeline_latent_consistency_text2img.py │ ├── latent_diffusion/ │ │ ├── __init__.py │ │ ├── pipeline_latent_diffusion.py │ │ └── pipeline_latent_diffusion_superresolution.py │ ├── latte/ │ │ ├── __init__.py │ │ └── pipeline_latte.py │ ├── ledits_pp/ │ │ ├── __init__.py │ │ ├── pipeline_leditspp_stable_diffusion.py │ │ ├── pipeline_leditspp_stable_diffusion_xl.py │ │ └── pipeline_output.py │ ├── lumina/ │ │ ├── __init__.py │ │ └── pipeline_lumina.py │ ├── marigold/ │ │ ├── __init__.py │ │ ├── marigold_image_processing.py │ │ ├── pipeline_marigold_depth.py │ │ └── pipeline_marigold_normals.py │ ├── musicldm/ │ │ ├── __init__.py │ │ └── pipeline_musicldm.py │ ├── onnx_utils.py │ ├── pag/ │ │ ├── __init__.py │ │ ├── pag_utils.py │ │ ├── pipeline_pag_controlnet_sd.py │ │ ├── pipeline_pag_controlnet_sd_xl.py │ │ ├── pipeline_pag_hunyuandit.py │ │ ├── pipeline_pag_kolors.py │ │ ├── pipeline_pag_pixart_sigma.py │ │ ├── pipeline_pag_sd.py │ │ ├── pipeline_pag_sd_3.py │ │ ├── pipeline_pag_sd_animatediff.py │ │ ├── pipeline_pag_sd_xl.py │ │ ├── pipeline_pag_sd_xl_img2img.py │ │ └── pipeline_pag_sd_xl_inpaint.py │ ├── paint_by_example/ │ │ ├── __init__.py │ │ ├── image_encoder.py │ │ └── pipeline_paint_by_example.py │ ├── pia/ │ │ ├── __init__.py │ │ └── pipeline_pia.py │ ├── pipeline_flax_utils.py │ ├── pipeline_loading_utils.py │ ├── pipeline_utils.py │ ├── pixart_alpha/ │ │ ├── __init__.py │ │ ├── pipeline_pixart_alpha.py │ │ └── pipeline_pixart_sigma.py │ ├── semantic_stable_diffusion/ │ │ ├── __init__.py │ │ ├── pipeline_output.py │ │ └── pipeline_semantic_stable_diffusion.py │ ├── shap_e/ │ │ ├── __init__.py │ │ ├── camera.py │ │ ├── pipeline_shap_e.py │ │ ├── pipeline_shap_e_img2img.py │ │ └── renderer.py │ ├── stable_audio/ │ │ ├── __init__.py │ │ ├── 
modeling_stable_audio.py │ │ └── pipeline_stable_audio.py │ ├── stable_cascade/ │ │ ├── __init__.py │ │ ├── pipeline_stable_cascade.py │ │ ├── pipeline_stable_cascade_combined.py │ │ └── pipeline_stable_cascade_prior.py │ ├── stable_diffusion/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── clip_image_project_model.py │ │ ├── convert_from_ckpt.py │ │ ├── pipeline_flax_stable_diffusion.py │ │ ├── pipeline_flax_stable_diffusion_img2img.py │ │ ├── pipeline_flax_stable_diffusion_inpaint.py │ │ ├── pipeline_onnx_stable_diffusion.py │ │ ├── pipeline_onnx_stable_diffusion_img2img.py │ │ ├── pipeline_onnx_stable_diffusion_inpaint.py │ │ ├── pipeline_onnx_stable_diffusion_upscale.py │ │ ├── pipeline_output.py │ │ ├── pipeline_stable_diffusion.py │ │ ├── pipeline_stable_diffusion_depth2img.py │ │ ├── pipeline_stable_diffusion_image_variation.py │ │ ├── pipeline_stable_diffusion_img2img.py │ │ ├── pipeline_stable_diffusion_inpaint.py │ │ ├── pipeline_stable_diffusion_instruct_pix2pix.py │ │ ├── pipeline_stable_diffusion_latent_upscale.py │ │ ├── pipeline_stable_diffusion_upscale.py │ │ ├── pipeline_stable_unclip.py │ │ ├── pipeline_stable_unclip_img2img.py │ │ ├── safety_checker.py │ │ ├── safety_checker_flax.py │ │ └── stable_unclip_image_normalizer.py │ ├── stable_diffusion_3/ │ │ ├── __init__.py │ │ ├── pipeline_output.py │ │ ├── pipeline_stable_diffusion_3.py │ │ ├── pipeline_stable_diffusion_3_img2img.py │ │ └── pipeline_stable_diffusion_3_inpaint.py │ ├── stable_diffusion_attend_and_excite/ │ │ ├── __init__.py │ │ └── pipeline_stable_diffusion_attend_and_excite.py │ ├── stable_diffusion_diffedit/ │ │ ├── __init__.py │ │ └── pipeline_stable_diffusion_diffedit.py │ ├── stable_diffusion_gligen/ │ │ ├── __init__.py │ │ ├── pipeline_stable_diffusion_gligen.py │ │ └── pipeline_stable_diffusion_gligen_text_image.py │ ├── stable_diffusion_k_diffusion/ │ │ ├── __init__.py │ │ ├── pipeline_stable_diffusion_k_diffusion.py │ │ └── pipeline_stable_diffusion_xl_k_diffusion.py │ ├── 
stable_diffusion_ldm3d/ │ │ ├── __init__.py │ │ └── pipeline_stable_diffusion_ldm3d.py │ ├── stable_diffusion_panorama/ │ │ ├── __init__.py │ │ └── pipeline_stable_diffusion_panorama.py │ ├── stable_diffusion_safe/ │ │ ├── __init__.py │ │ ├── pipeline_output.py │ │ ├── pipeline_stable_diffusion_safe.py │ │ └── safety_checker.py │ ├── stable_diffusion_sag/ │ │ ├── __init__.py │ │ └── pipeline_stable_diffusion_sag.py │ ├── stable_diffusion_xl/ │ │ ├── __init__.py │ │ ├── pipeline_flax_stable_diffusion_xl.py │ │ ├── pipeline_output.py │ │ ├── pipeline_stable_diffusion_xl.py │ │ ├── pipeline_stable_diffusion_xl_img2img.py │ │ ├── pipeline_stable_diffusion_xl_inpaint.py │ │ ├── pipeline_stable_diffusion_xl_instruct_pix2pix.py │ │ └── watermark.py │ ├── stable_video_diffusion/ │ │ ├── __init__.py │ │ └── pipeline_stable_video_diffusion.py │ ├── t2i_adapter/ │ │ ├── __init__.py │ │ ├── pipeline_stable_diffusion_adapter.py │ │ └── pipeline_stable_diffusion_xl_adapter.py │ ├── text_to_video_synthesis/ │ │ ├── __init__.py │ │ ├── pipeline_output.py │ │ ├── pipeline_text_to_video_synth.py │ │ ├── pipeline_text_to_video_synth_img2img.py │ │ ├── pipeline_text_to_video_zero.py │ │ └── pipeline_text_to_video_zero_sdxl.py │ ├── unclip/ │ │ ├── __init__.py │ │ ├── pipeline_unclip.py │ │ ├── pipeline_unclip_image_variation.py │ │ └── text_proj.py │ ├── unidiffuser/ │ │ ├── __init__.py │ │ ├── modeling_text_decoder.py │ │ ├── modeling_uvit.py │ │ └── pipeline_unidiffuser.py │ └── wuerstchen/ │ ├── __init__.py │ ├── modeling_paella_vq_model.py │ ├── modeling_wuerstchen_common.py │ ├── modeling_wuerstchen_diffnext.py │ ├── modeling_wuerstchen_prior.py │ ├── pipeline_wuerstchen.py │ ├── pipeline_wuerstchen_combined.py │ └── pipeline_wuerstchen_prior.py ├── py.typed ├── schedulers/ │ ├── README.md │ ├── __init__.py │ ├── deprecated/ │ │ ├── __init__.py │ │ ├── scheduling_karras_ve.py │ │ └── scheduling_sde_vp.py │ ├── scheduling_amused.py │ ├── scheduling_consistency_decoder.py │ ├── 
scheduling_consistency_models.py │ ├── scheduling_cosine_dpmsolver_multistep.py │ ├── scheduling_ddim.py │ ├── scheduling_ddim_cogvideox.py │ ├── scheduling_ddim_flax.py │ ├── scheduling_ddim_inverse.py │ ├── scheduling_ddim_parallel.py │ ├── scheduling_ddpm.py │ ├── scheduling_ddpm_flax.py │ ├── scheduling_ddpm_parallel.py │ ├── scheduling_ddpm_wuerstchen.py │ ├── scheduling_deis_multistep.py │ ├── scheduling_dpm_cogvideox.py │ ├── scheduling_dpmsolver_multistep.py │ ├── scheduling_dpmsolver_multistep_flax.py │ ├── scheduling_dpmsolver_multistep_inverse.py │ ├── scheduling_dpmsolver_sde.py │ ├── scheduling_dpmsolver_singlestep.py │ ├── scheduling_edm_dpmsolver_multistep.py │ ├── scheduling_edm_euler.py │ ├── scheduling_euler_ancestral_discrete.py │ ├── scheduling_euler_discrete.py │ ├── scheduling_euler_discrete_flax.py │ ├── scheduling_flow_match_euler_discrete.py │ ├── scheduling_flow_match_heun_discrete.py │ ├── scheduling_heun_discrete.py │ ├── scheduling_ipndm.py │ ├── scheduling_k_dpm_2_ancestral_discrete.py │ ├── scheduling_k_dpm_2_discrete.py │ ├── scheduling_karras_ve_flax.py │ ├── scheduling_lcm.py │ ├── scheduling_lms_discrete.py │ ├── scheduling_lms_discrete_flax.py │ ├── scheduling_pndm.py │ ├── scheduling_pndm_flax.py │ ├── scheduling_repaint.py │ ├── scheduling_sasolver.py │ ├── scheduling_sde_ve.py │ ├── scheduling_sde_ve_flax.py │ ├── scheduling_tcd.py │ ├── scheduling_unclip.py │ ├── scheduling_unipc_multistep.py │ ├── scheduling_utils.py │ ├── scheduling_utils_flax.py │ └── scheduling_vq_diffusion.py ├── training_utils.py ├── utils/ │ ├── __init__.py │ ├── accelerate_utils.py │ ├── constants.py │ ├── deprecation_utils.py │ ├── doc_utils.py │ ├── dummy_flax_and_transformers_objects.py │ ├── dummy_flax_objects.py │ ├── dummy_note_seq_objects.py │ ├── dummy_onnx_objects.py │ ├── dummy_pt_objects.py │ ├── dummy_torch_and_librosa_objects.py │ ├── dummy_torch_and_scipy_objects.py │ ├── dummy_torch_and_torchsde_objects.py │ ├── 
dummy_torch_and_transformers_and_k_diffusion_objects.py │ ├── dummy_torch_and_transformers_and_onnx_objects.py │ ├── dummy_torch_and_transformers_and_sentencepiece_objects.py │ ├── dummy_torch_and_transformers_objects.py │ ├── dummy_transformers_and_torch_and_note_seq_objects.py │ ├── dynamic_modules_utils.py │ ├── export_utils.py │ ├── hub_utils.py │ ├── import_utils.py │ ├── loading_utils.py │ ├── logging.py │ ├── model_card_template.md │ ├── outputs.py │ ├── peft_utils.py │ ├── pil_utils.py │ ├── state_dict_utils.py │ ├── testing_utils.py │ ├── torch_utils.py │ └── versions.py └── video_processor.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Initially taken from GitHub's Python gitignore file # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # tests and logs tests/fixtures/cached_*_text.txt logs/ lightning_logs/ lang_code_data/ # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a Python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # vscode .vs .vscode # Pycharm .idea # TF code tensorflow_code # Models proc_data # examples runs /runs_old /wandb /examples/runs /examples/**/*.args /examples/rag/sweep # data /data serialization_dir # emacs *.*~ debug.env # vim .*.swp # ctags tags # pre-commit .pre-commit* # .lock *.lock # DS_Store (MacOS) .DS_Store # RL pipelines may produce mp4 outputs *.mp4 # dependencies /transformers # ruff .ruff_cache # wandb wandb # Project stuff outputs/ ================================================ FILE: LICENSE ================================================ Copyright (c) 2025 Snap Inc. Copyright (c) 2025 The Hebrew University of Jerusalem Copyright (c) 2025 Tel Aviv University Copyright (c) 2025 Reichman University This sample code is made available for non-commercial, research purposes only. Non-commercial means not primarily intended for or directed towards commercial advantage or monetary compensation. Research purposes mean solely for study, instruction, or non-commercial research, testing or validation.
Furthermore, this sample code incorporates the FLUX.1 [dev] model licensed under the FLUX.1 [dev] Non-Commercial License by Black Forest Labs copied below. Please note the further use restrictions outlined in Section 4. No commercial license, whether implied or otherwise, is granted in or to this code, unless you have entered into separate agreements with the authors and/or Black Forest Labs for such rights. This sample code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, title, fitness for a particular purpose, non-infringement, or that the code is free of defects, errors or viruses. In no event will the authors be liable for any damages or losses of any kind arising from this sample code or your use thereof. Any redistribution of this sample code, including in binary form, must retain or reproduce the above copyright notice, conditions and disclaimer, as well as the FLUX.1 [dev] Non-Commercial license text copied below. — FLUX.1 [dev] Model - Non-Commercial License Black Forest Labs, Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] AI models, including FLUX.1 [dev], FLUX.1 Fill [dev], FLUX.1 Depth [dev], FLUX.1 Canny [dev], FLUX.1 Redux [dev], FLUX.1 Canny [dev] LoRA and FLUX.1 Depth [dev] LoRA, and their elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI models made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”). 
By downloading, accessing, use, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity. 1. Definitions. Capitalized terms used in this License but not defined herein have the following meanings: a. “Derivative” means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License. b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX.1 [dev] Models and/or the Derivatives as the case may be. c. 
“Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the model or its output: (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment, (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. For clarity, use for revenue-generating activity or direct interactions with or impacts on end users, or use to train, fine tune or distill other models for commercial use is not a Non-Commercial purpose. d. “Outputs” means any content generated by the operation of the FLUX.1 [dev] Models or the Derivatives from a prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of a FLUX.1 [dev] Models, such as any fine-tuned versions of the FLUX.1 [dev] Models, the weights, or parameters. e. “you” or “your” means the individual or entity entering into this License with Company. 2. License Grant. a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license to access, use, create Derivatives of, and Distribute the FLUX.1 [dev] Models solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Company’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. 
Any restrictions set forth herein in regarding the FLUX.1 [dev] Model also applies to any Derivative you create or that are created on your behalf. b. Non-Commercial Use Only. You may only access, use, Distribute, or creative Derivatives of or the FLUX.1 [dev] Model or Derivatives for Non-Commercial Purposes. If You want to use a FLUX.1 [dev] Model a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Company’s sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please contact Company at the following e-mail address if you want to discuss such a license: info@blackforestlabs.ai. c. Reserved Rights. The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX.1 [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Company and its licensors reserve all rights not expressly granted by this License. d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein. You may not use the Output to train, fine-tune or distill a model that is competitive with the FLUX.1 [dev] Model. 3. Distribution. Subject to this License, you may Distribute copies of the FLUX.1 [dev] Model and/or Derivatives made by you, under the following conditions: a. you must make available a copy of this License to third-party recipients of the FLUX.1 [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX.1 [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License; b. 
you must make prominently display the following notice alongside the Distribution of the FLUX.1 [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX.1 [dev] Model or Derivative) (the “Attribution Notice”): “The FLUX.1 [dev] Model is licensed by Black Forest Labs. Inc. under the FLUX.1 [dev] Non-Commercial License. Copyright Black Forest Labs. Inc. IN NO EVENT SHALL BLACK FOREST LABS, INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.” c. in the case of Distribution of Derivatives made by you, you must also include in the Attribution Notice a statement that you have modified the applicable FLUX.1 [dev] Model; and d. in the case of Distribution of Derivatives made by you, any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients’ use of the FLUX.1 [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions. e. In the case of Distribution of Derivatives made by you, you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX.1 [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing. 4. Restrictions. You will not, and will not permit, assist or cause any third party to a. 
use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX.1 [dev] Model (or any Derivative thereof, or any data produced by the FLUX.1 [dev] Model), in whole or in part, for (i) any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates any third-party rights, or (vi) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and all laws governing the processing of biometric information), as well as all amendments and successor laws to any of the foregoing; b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX.1 [dev] Model; c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX.1 [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX.1 [dev] Model; or d. offer or impose any terms on the FLUX.1 [dev] Model that alter, restrict, or are inconsistent with the terms of this License. e. violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX.1 [dev] Model; f. directly or indirectly Distribute, export, or otherwise transfer FLUX.1 [dev] Model (a) to any individual, entity, or country prohibited by Export Laws; (b) to anyone on U.S. or non-U.S. 
government restricted parties lists; or (c) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; 3) use or download FLUX.1 [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (4) will not disguise your location through IP proxying or other methods. 5. DISCLAIMERS. THE FLUX.1 [dev] MODEL IS PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX.1 [dev] MODEL, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX.1 [dev] MODEL WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS. 6. LIMITATION OF LIABILITY. TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 
THE FLUX.1 [dev] MODEL, ITS CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE. 7. INDEMNIFICATION You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to (a) your access to or use of the FLUX.1 [dev] Model (as well as any Output, results or data generated from such access or use), including any High-Risk Use (defined below); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. 
You will also grant the Company Parties sole control of the defense or settlement, at Company’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties. 8. Termination; Survival. a. This License will automatically terminate upon any breach by you of the terms of this License. b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you. c. If You initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX.1 [dev] Model or any Derivative, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated. d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX.1 [dev] Model and any Derivatives. The following sections survive termination of this License 2(c), 2(d), 4-11. 9. Third Party Materials. The FLUX.1 [dev] Model may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk. 10. Trademarks. 
You have not been granted any trademark license as part of this License and may not use any name or mark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX.1 [dev] Model and its creators. 11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the Documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter. No change or addition to any provision of this License will be binding unless it is in writing and signed by an authorized representative of both you and Company. ================================================ FILE: README.md ================================================ # Stable Flow: Vital Layers for Training-Free Image Editing [CVPR 2025] > **Stable Flow: Vital Layers for Training-Free Image Editing** > > Omri Avrahami, Or Patashnik, Ohad Fried, Egor Nemchinov, Kfir Aberman, Dani Lischinski, Daniel Cohen-Or > > Abstract: Diffusion models have revolutionized the field of content synthesis and editing. 
Recent models have replaced the traditional UNet architecture with the Diffusion Transformer (DiT), and employed flow-matching for improved training and sampling. However, they exhibit limited generation diversity. In this work, we leverage this limitation to perform consistent image edits via selective injection of attention features. The main challenge is that, unlike the UNet-based models, DiT lacks a coarse-to-fine synthesis structure, making it unclear in which layers to perform the injection. Therefore, we propose an automatic method to identify "vital layers" within DiT, crucial for image formation, and demonstrate how these layers facilitate a range of controlled stable edits, from non-rigid modifications to object addition, using the same mechanism. Next, to enable real-image editing, we introduce an improved image inversion method for flow models. Finally, we evaluate our approach through qualitative and quantitative comparisons, along with a user study, and demonstrate its effectiveness across multiple applications.


Stable Flow is a training-free editing method that can perform various types of image editing operations, including non-rigid editing, object addition, object removal, and global scene editing. All of these edits are performed using the same mechanism.

# Applications ## Incremental Editing ## Consistent Style ## Text Editing # Installation Install the conda virtual environment: ```bash conda env create -f environment.yml conda activate stable-flow ``` You may need to update to your own `pytorch-cuda` version. # Usage ## Generated Images Editing You need to provide a list of prompts, where the first prompt describes the original scene, and the rest of the prompts describe the edited scene. For example: ```python [ "A photo of a dog in standing the street", "A photo of a dog sitting in the street", "A photo of a dog in standing and wearing a straw hat the street", "A photo of a mink", ] ``` Then, you can generate a batch of the image and its edited versions using: ```bash python run_stable_flow.py \ --hf_token YOUR_PERSONAL_HUGGINGFACE_TOKEN \ --prompts "A photo of a dog in standing the street" \ "A photo of a dog sitting in the street" \ "A photo of a dog in standing and wearing a straw hat the street" \ "A photo of a mink" ``` where `YOUR_PERSONAL_HUGGINGFACE_TOKEN` is your [personal HuggingFace user access token](https://huggingface.co/docs/hub/en/security-tokens). Then, the results will be saved to `outputs/result.jpg`, or any other path you specify under `--output_path`. ### Low GPU VRAM Usage All the experiments in our paper were conducted using a single NVIDIA V100 GPU of 80GB VRAM. If you are using a GPU with smaller VRAM and getting `CUDA out of memory` errors, you can use [CPU offloading](https://huggingface.co/docs/diffusers/en/optimization/memory) by adding the `--cpu_offload` flag. It will reduce memory consumption but it will increase the inference time significantly. ## Real Images Editing (Optional) If you want to edit a real image, you still need to provide a list of prompts, where the first prompt describes the input image and the rest of the prompts describe the edited scene. 
For example: ```bash python run_stable_flow.py \ --hf_token YOUR_PERSONAL_HUGGINGFACE_TOKEN \ --input_img_path inputs/bottle.jpg \ --prompts "A photo of a bottle" \ "A photo of a bottle next to an apple" ``` where `YOUR_PERSONAL_HUGGINGFACE_TOKEN` is your [personal HuggingFace user access token](https://huggingface.co/docs/hub/en/security-tokens), and `--input_img_path` is the path to the input image. Then, the results will be saved to `outputs/result.jpg`, or any other path you specify under `--output_path`. ## Citation If you find this useful for your research, please cite the following: ```bibtex @InProceedings{Avrahami_2025_CVPR, author = {Avrahami, Omri and Patashnik, Or and Fried, Ohad and Nemchinov, Egor and Aberman, Kfir and Lischinski, Dani and Cohen-Or, Daniel}, title = {Stable Flow: Vital Layers for Training-Free Image Editing}, booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)}, month = {June}, year = {2025}, pages = {7877-7888} } ``` ================================================ FILE: environment.yml ================================================ name: stable-flow channels: - pytorch - nvidia - defaults dependencies: - _libgcc_mutex=0.1=main - _openmp_mutex=5.1=1_gnu - blas=1.0=mkl - brotli-python=1.0.9=py311h6a678d5_8 - bzip2=1.0.8=h5eee18b_6 - ca-certificates=2024.7.2=h06a4308_0 - certifi=2024.8.30=py311h06a4308_0 - charset-normalizer=3.3.2=pyhd3eb1b0_0 - cuda-cudart=12.1.105=0 - cuda-cupti=12.1.105=0 - cuda-libraries=12.1.0=0 - cuda-nvrtc=12.1.105=0 - cuda-nvtx=12.1.105=0 - cuda-opencl=12.6.68=0 - cuda-runtime=12.1.0=0 - cuda-version=12.6=3 - ffmpeg=4.3=hf484d3e_0 - filelock=3.13.1=py311h06a4308_0 - freetype=2.12.1=h4a9f257_0 - gmp=6.2.1=h295c915_3 - gmpy2=2.1.2=py311hc9b5ff0_0 - gnutls=3.6.15=he1e5248_0 - idna=3.7=py311h06a4308_0 - intel-openmp=2023.1.0=hdb19cb5_46306 - jinja2=3.1.4=py311h06a4308_0 - jpeg=9e=h5eee18b_3 - lame=3.100=h7b6447c_0 - lcms2=2.12=h3be6417_0 - ld_impl_linux-64=2.40=h12ee557_0 - 
lerc=3.0=h295c915_0 - libcublas=12.1.0.26=0 - libcufft=11.0.2.4=0 - libcufile=1.11.1.6=0 - libcurand=10.3.7.68=0 - libcusolver=11.4.4.55=0 - libcusparse=12.0.2.55=0 - libdeflate=1.17=h5eee18b_1 - libffi=3.4.4=h6a678d5_1 - libgcc-ng=11.2.0=h1234567_1 - libgomp=11.2.0=h1234567_1 - libiconv=1.16=h5eee18b_3 - libidn2=2.3.4=h5eee18b_0 - libjpeg-turbo=2.0.0=h9bf148f_0 - libnpp=12.0.2.50=0 - libnvjitlink=12.1.105=0 - libnvjpeg=12.1.1.14=0 - libpng=1.6.39=h5eee18b_0 - libstdcxx-ng=11.2.0=h1234567_1 - libtasn1=4.19.0=h5eee18b_0 - libtiff=4.5.1=h6a678d5_0 - libunistring=0.9.10=h27cfd23_0 - libuuid=1.41.5=h5eee18b_0 - libwebp-base=1.3.2=h5eee18b_0 - llvm-openmp=14.0.6=h9e868ea_0 - lz4-c=1.9.4=h6a678d5_1 - markupsafe=2.1.3=py311h5eee18b_0 - mkl=2023.1.0=h213fc3f_46344 - mkl-service=2.4.0=py311h5eee18b_1 - mkl_fft=1.3.10=py311h5eee18b_0 - mkl_random=1.2.7=py311ha02d727_0 - mpc=1.1.0=h10f8cd9_1 - mpfr=4.0.2=hb69a4c5_1 - mpmath=1.3.0=py311h06a4308_0 - ncurses=6.4=h6a678d5_0 - nettle=3.7.3=hbbd107a_1 - networkx=3.2.1=py311h06a4308_0 - numpy=2.0.1=py311h08b1b3b_1 - numpy-base=2.0.1=py311hf175353_1 - openh264=2.1.1=h4ff587b_0 - openjpeg=2.5.2=he7f1fd0_0 - openssl=3.0.15=h5eee18b_0 - pillow=10.4.0=py311h5eee18b_0 - pip=24.2=py311h06a4308_0 - pysocks=1.7.1=py311h06a4308_0 - python=3.11.9=h955ad1f_0 - pytorch=2.4.1=py3.11_cuda12.1_cudnn9.1.0_0 - pytorch-cuda=12.1=ha16c6d3_5 - pytorch-mutex=1.0=cuda - pyyaml=6.0.1=py311h5eee18b_0 - readline=8.2=h5eee18b_0 - requests=2.32.3=py311h06a4308_0 - setuptools=75.1.0=py311h06a4308_0 - sqlite=3.45.3=h5eee18b_0 - sympy=1.13.2=py311h06a4308_0 - tbb=2021.8.0=hdb19cb5_0 - tk=8.6.14=h39e8969_0 - torchtriton=3.0.0=py311 - torchvision=0.19.1=py311_cu121 - typing_extensions=4.11.0=py311h06a4308_0 - urllib3=2.2.2=py311h06a4308_0 - wheel=0.44.0=py311h06a4308_0 - xz=5.4.6=h5eee18b_1 - yaml=0.2.5=h7b6447c_0 - zlib=1.2.13=h5eee18b_1 - zstd=1.5.5=hc292b87_2 - pip: - accelerate==0.34.2 - contourpy==1.3.0 - cycler==0.12.1 - fonttools==4.54.1 - fsspec==2024.9.0 - 
class StableFlow:
    """Command-line driver that runs ``FluxPipeline`` for generated- or real-image editing.

    Constructing an instance parses ``sys.argv`` and loads the pipeline. Afterwards call
    :meth:`infer_and_save` (no input image) or :meth:`invert_and_save` (real input image).
    In both cases the first prompt describes the original scene and the remaining prompts
    describe the edits; all resulting images are stacked side by side into one output file.
    """

    # Block indices forwarded to the pipeline as ``mm_copy_blocks``.
    MULTIMODAL_VITAL_LAYERS = [0, 1, 17, 18]
    # Block indices forwarded to the pipeline as ``single_copy_blocks``. The literal values
    # are shifted down by 19 -- presumably the number of multimodal blocks that precede the
    # single-modal ones in the transformer; TODO confirm against the model definition.
    SINGLE_MODAL_VITAL_LAYERS = list(np.array([28, 53, 54, 56, 25]) - 19)

    def __init__(self):
        self._parse_args()
        self._load_pipeline()

    def _parse_args(self):
        """Parse the command line into ``self.args`` and make sure the output directory exists."""
        parser = argparse.ArgumentParser()
        parser.add_argument("--model_path", type=str, default="black-forest-labs/FLUX.1-dev")
        parser.add_argument("--hf_token", type=str, required=True)
        parser.add_argument("--prompts", type=str, nargs="+", required=True)
        parser.add_argument("--output_path", type=str, default="outputs/result.jpg")
        parser.add_argument("--input_img_path", type=str, default=None)
        parser.add_argument("--seed", type=int, default=42)
        parser.add_argument("--cpu_offload", action="store_true")
        parser.add_argument("--device", type=str, default="cuda")

        self.args = parser.parse_args()

        # A bare filename (e.g. ``--output_path result.jpg``) has an empty dirname, and
        # ``os.makedirs("")`` raises FileNotFoundError, so only create a directory when
        # the path actually names one.
        output_dir = os.path.dirname(self.args.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    def _load_pipeline(self):
        """Load the pipeline in float16 and place it according to ``--cpu_offload`` / ``--device``."""
        self.pipe = FluxPipeline.from_pretrained(
            self.args.model_path,
            torch_dtype=torch.float16,
            visualize_attention=False,
            token=self.args.hf_token
        )

        if self.args.cpu_offload:
            self.pipe.enable_sequential_cpu_offload()
        else:
            self.pipe.to(self.args.device)

    @torch.no_grad()
    def infer_and_save(self, prompts):
        """Generate one image per prompt from a shared seeded latent and save them side by side.

        Args:
            prompts: List of prompt strings passed to the pipeline as a single batch.
        """
        # One (4096, 64) noise sample is drawn and tiled along a new leading axis so that
        # every prompt in the batch starts from the same latent.
        latents = torch.randn(
            (4096, 64),
            generator=torch.Generator(0).manual_seed(self.args.seed),
            device=self.args.device,
            dtype=torch.float16
        ).tile(len(prompts), 1, 1)

        images = self.pipe(
            prompts,
            height=1024,
            width=1024,
            guidance_scale=3.5,
            output_type="pil",
            num_inference_steps=15,
            max_sequence_length=512,
            latents=latents,
            mm_copy_blocks=StableFlow.MULTIMODAL_VITAL_LAYERS,
            single_copy_blocks=StableFlow.SINGLE_MODAL_VITAL_LAYERS,
        ).images

        # Concatenate all results horizontally into a single image.
        images = [np.array(img) for img in images]
        res = Image.fromarray(np.hstack((images)))
        res.save(self.args.output_path)

    @torch.no_grad()
    def image2latent(self, image, latent_nudging_scalar=1.15):
        """Encode an image with the pipeline's VAE and return packed, scaled latents.

        Args:
            image: Input accepted by ``self.pipe.image_processor.preprocess``.
            latent_nudging_scalar: Factor the shifted-and-scaled latents are multiplied by
                before packing.

        Returns:
            The result of ``self.pipe._pack_latents`` for a batch of one with 16 latent
            channels and a 128x128 latent grid (assumes the preprocessed image encodes to
            that size, i.e. presumably a 1024x1024 input -- TODO confirm).
        """
        # Use the configured device rather than a hard-coded "cuda" so that ``--device``
        # is honoured here just as it is when the pipeline itself is placed.
        image = self.pipe.image_processor.preprocess(image).type(self.pipe.vae.dtype).to(self.args.device)
        # Take the mean of the encoder's latent distribution instead of sampling from it.
        latents = self.pipe.vae.encode(image)["latent_dist"].mean
        latents = (latents - self.pipe.vae.config.shift_factor) * self.pipe.vae.config.scaling_factor
        latents = latents * latent_nudging_scalar

        latents = self.pipe._pack_latents(
            latents=latents,
            batch_size=1,
            num_channels_latents=16,
            height=128,
            width=128
        )

        return latents

    @torch.no_grad()
    def invert_and_save(self, prompts):
        """Invert the input image with the first prompt, then edit it with all prompts.

        Args:
            prompts: List of prompt strings; ``prompts[0]`` is used for the inversion pass
                and the full list is used for the editing pass.
        """
        inversion_prompt = prompts[0:1]

        # Invert: with ``invert_image=True`` the pipeline result is used below as an
        # indexable sequence of latents.
        inverted_latent_list = self.pipe(
            inversion_prompt,
            height=1024,
            width=1024,
            guidance_scale=1,
            output_type="pil",
            num_inference_steps=50,
            max_sequence_length=512,
            latents=self.image2latent(Image.open(self.args.input_img_path)),
            invert_image=True
        )

        # Edit: every prompt starts from the last inverted latent; the first prompt runs
        # with guidance scale 1 and each remaining prompt with guidance scale 3.
        images = self.pipe(
            prompts,
            height=1024,
            width=1024,
            guidance_scale=[1] + [3] * (len(prompts) - 1),
            output_type="pil",
            num_inference_steps=50,
            max_sequence_length=512,
            latents=inverted_latent_list[-1].tile(len(prompts), 1, 1),
            inverted_latent_list=inverted_latent_list,
            mm_copy_blocks=StableFlow.MULTIMODAL_VITAL_LAYERS,
            single_copy_blocks=StableFlow.SINGLE_MODAL_VITAL_LAYERS,
        ).images

        # Concatenate all results horizontally into a single image.
        images = [np.array(img) for img in images]
        res = Image.fromarray(np.hstack((images)))
        res.save(self.args.output_path)


if __name__ == "__main__":
    stable_flow = StableFlow()

    # Without an input image the prompts drive pure generation; with one, the image is
    # inverted first and then edited.
    if stable_flow.args.input_img_path is None:
        stable_flow.infer_and_save(prompts=stable_flow.args.prompts)
    else:
        stable_flow.invert_and_save(prompts=stable_flow.args.prompts)
Build both the sources and the wheel. Do not change anything in setup.py between creating the wheel and the source distribution (obviously). For the wheel, run: "python setup.py bdist_wheel" in the top level directory (This will build a wheel for the Python version you use to build it). For the sources, run: "python setup.py sdist" You should now have a /dist directory with both .whl and .tar.gz source versions. Long story cut short, you need to run both before you can upload the distribution to the test PyPI and the actual PyPI servers: python setup.py bdist_wheel && python setup.py sdist 8. Check that everything looks correct by uploading the package to the PyPI test server: twine upload dist/* -r pypitest (pypi suggests using twine as other methods upload files via plaintext.) You may have to specify the repository url, use the following command then: twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ Check that you can install it in a virtualenv by running: pip install -i https://testpypi.python.org/pypi diffusers If you are testing from a Colab Notebook, for instance, then do: pip install diffusers && pip uninstall diffusers pip install -i https://testpypi.python.org/pypi diffusers Check you can run the following commands: python -c "from diffusers import __version__; print(__version__)" python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()" python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')" python -c "from diffusers import *" 9. Upload the final version to the actual PyPI: twine upload dist/* -r pypi 10. Prepare the release notes and publish them on GitHub once everything is looking hunky-dory. 
You can use the following Space to fetch all the commits applicable for the release: https://huggingface.co/spaces/lysandre/github-release. Repo should be `huggingface/diffusers`. `tag` should be the previous release tag (v0.26.1, for example), and `branch` should be the latest release branch (v0.27.0-release, for example). It denotes all commits that have happened on branch v0.27.0-release after the tag v0.26.1 was created. 11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release, you need to go back to main before executing this. """ import os import re import sys from setuptools import Command, find_packages, setup # IMPORTANT: # 1. all dependencies should be listed here with their version requirements if any # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py _deps = [ "Pillow", # keep the PIL.Image.Resampling deprecation away "accelerate>=0.31.0", "compel==0.1.8", "datasets", "filelock", "flax>=0.4.1", "hf-doc-builder>=0.3.0", "huggingface-hub>=0.23.2", "requests-mock==1.10.0", "importlib_metadata", "invisible-watermark>=0.2.0", "isort>=5.5.4", "jax>=0.4.1", "jaxlib>=0.4.1", "Jinja2", "k-diffusion>=0.0.12", "torchsde", "note_seq", "librosa", "numpy", "parameterized", "peft>=0.6.0", "protobuf>=3.20.3,<4", "pytest", "pytest-timeout", "pytest-xdist", "python>=3.8.0", "ruff==0.1.5", "safetensors>=0.3.1", "sentencepiece>=0.1.91,!=0.1.92", "GitPython<3.1.19", "scipy", "onnx", "regex!=2019.12.17", "requests", "tensorboard", "torch>=1.4", "torchvision", "transformers>=4.41.2", "urllib3<=2.0.0", "black", ] # this is a lookup table with items like: # # tokenizers: "huggingface-hub==0.8.0" # packaging: "packaging" # # some of the values are versioned whereas others aren't. 
def deps_list(*pkgs):
    """Return the requirement string registered in ``deps`` for each given package name."""
    requirements = []
    for pkg_name in pkgs:
        requirements.append(deps[pkg_name])
    return requirements


class DepsTableUpdateCommand(Command):
    """
    Custom setup command that regenerates the dependency versions table.

    usage: python setup.py deps_table_update
    """

    description = "build runtime dependency table"
    user_options = [
        # format: (long option, short option, description).
        (
            "dep-table-update",
            None,
            "updates src/diffusers/dependency_versions_table.py",
        ),
    ]

    def initialize_options(self):
        # Nothing to initialise; the command takes no configurable options.
        pass

    def finalize_options(self):
        # Nothing to validate.
        pass

    def run(self):
        """Write every ``deps`` entry into the generated dependency table module."""
        table_rows = [f' "{name}": "{spec}",' for name, spec in deps.items()]
        file_lines = [
            "# THIS FILE HAS BEEN AUTOGENERATED. To update:",
            "# 1. modify the `_deps` dict in setup.py",
            "# 2. run `make deps_table_update`",
            "deps = {",
            "\n".join(table_rows),
            "}",
            "",
        ]
        target = "src/diffusers/dependency_versions_table.py"
        print(f"updating {target}")
        # newline="\n" keeps the generated file's line endings the same on every platform.
        with open(target, "w", encoding="utf-8", newline="\n") as table_file:
            table_file.write("\n".join(file_lines))
packages=find_packages("src"), package_data={"diffusers": ["py.typed"]}, include_package_data=True, python_requires=">=3.8.0", install_requires=list(install_requires), extras_require=extras, entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]}, classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "Intended Audience :: Education", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Programming Language :: Python :: 3", ] + [f"Programming Language :: Python :: 3.{i}" for i in range(8, version_range_max)], cmdclass={"deps_table_update": DepsTableUpdateCommand}, ) ================================================ FILE: src/diffusers/__init__.py ================================================ __version__ = "0.30.0" from typing import TYPE_CHECKING from .utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_k_diffusion_available, is_librosa_available, is_note_seq_available, is_onnx_available, is_scipy_available, is_sentencepiece_available, is_torch_available, is_torchsde_available, is_transformers_available, ) # Lazy Import based on # https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py # When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names, # and is used to defer the actual importing for when the objects are requested. # This way `import diffusers` provides the names in the namespace without actually importing anything (and especially none of the backends). 
_import_structure = { "configuration_utils": ["ConfigMixin"], "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], "schedulers": [], "utils": [ "OptionalDependencyNotAvailable", "is_flax_available", "is_inflect_available", "is_invisible_watermark_available", "is_k_diffusion_available", "is_k_diffusion_version", "is_librosa_available", "is_note_seq_available", "is_onnx_available", "is_scipy_available", "is_torch_available", "is_torchsde_available", "is_transformers_available", "is_transformers_version", "is_unidecode_available", "logging", ], } try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_onnx_objects # noqa F403 _import_structure["utils.dummy_onnx_objects"] = [ name for name in dir(dummy_onnx_objects) if not name.startswith("_") ] else: _import_structure["pipelines"].extend(["OnnxRuntimeModel"]) try: if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_pt_objects # noqa F403 _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: _import_structure["models"].extend( [ "AsymmetricAutoencoderKL", "AuraFlowTransformer2DModel", "AutoencoderKL", "AutoencoderKLCogVideoX", "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", "AutoencoderTiny", "CogVideoXTransformer3DModel", "ConsistencyDecoderVAE", "ControlNetModel", "ControlNetXSAdapter", "DiTTransformer2DModel", "FluxTransformer2DModel", "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", "HunyuanDiT2DMultiControlNetModel", "I2VGenXLUNet", "Kandinsky3UNet", "LatteTransformer3DModel", "LuminaNextDiT2DModel", "ModelMixin", "MotionAdapter", "MultiAdapter", "PixArtTransformer2DModel", "PriorTransformer", "SD3ControlNetModel", "SD3MultiControlNetModel", "SD3Transformer2DModel", "SparseControlNetModel", "StableAudioDiTModel", "StableCascadeUNet", "T2IAdapter", "T5FilmDecoder", 
"Transformer2DModel", "UNet1DModel", "UNet2DConditionModel", "UNet2DModel", "UNet3DConditionModel", "UNetControlNetXSModel", "UNetMotionModel", "UNetSpatioTemporalConditionModel", "UVit2DModel", "VQModel", ] ) _import_structure["optimization"] = [ "get_constant_schedule", "get_constant_schedule_with_warmup", "get_cosine_schedule_with_warmup", "get_cosine_with_hard_restarts_schedule_with_warmup", "get_linear_schedule_with_warmup", "get_polynomial_decay_schedule_with_warmup", "get_scheduler", ] _import_structure["pipelines"].extend( [ "AudioPipelineOutput", "AutoPipelineForImage2Image", "AutoPipelineForInpainting", "AutoPipelineForText2Image", "ConsistencyModelPipeline", "DanceDiffusionPipeline", "DDIMPipeline", "DDPMPipeline", "DiffusionPipeline", "DiTPipeline", "ImagePipelineOutput", "KarrasVePipeline", "LDMPipeline", "LDMSuperResolutionPipeline", "PNDMPipeline", "RePaintPipeline", "ScoreSdeVePipeline", "StableDiffusionMixin", ] ) _import_structure["schedulers"].extend( [ "AmusedScheduler", "CMStochasticIterativeScheduler", "CogVideoXDDIMScheduler", "CogVideoXDPMScheduler", "DDIMInverseScheduler", "DDIMParallelScheduler", "DDIMScheduler", "DDPMParallelScheduler", "DDPMScheduler", "DDPMWuerstchenScheduler", "DEISMultistepScheduler", "DPMSolverMultistepInverseScheduler", "DPMSolverMultistepScheduler", "DPMSolverSinglestepScheduler", "EDMDPMSolverMultistepScheduler", "EDMEulerScheduler", "EulerAncestralDiscreteScheduler", "EulerDiscreteScheduler", "FlowMatchEulerDiscreteScheduler", "FlowMatchHeunDiscreteScheduler", "HeunDiscreteScheduler", "IPNDMScheduler", "KarrasVeScheduler", "KDPM2AncestralDiscreteScheduler", "KDPM2DiscreteScheduler", "LCMScheduler", "PNDMScheduler", "RePaintScheduler", "SASolverScheduler", "SchedulerMixin", "ScoreSdeVeScheduler", "TCDScheduler", "UnCLIPScheduler", "UniPCMultistepScheduler", "VQDiffusionScheduler", ] ) _import_structure["training_utils"] = ["EMAModel"] try: if not (is_torch_available() and is_scipy_available()): raise 
OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_torch_and_scipy_objects # noqa F403 _import_structure["utils.dummy_torch_and_scipy_objects"] = [ name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_") ] else: _import_structure["schedulers"].extend(["LMSDiscreteScheduler"]) try: if not (is_torch_available() and is_torchsde_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_torch_and_torchsde_objects # noqa F403 _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_") ] else: _import_structure["schedulers"].extend(["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"]) try: if not (is_torch_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_torch_and_transformers_objects # noqa F403 _import_structure["utils.dummy_torch_and_transformers_objects"] = [ name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") ] else: _import_structure["pipelines"].extend( [ "AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline", "AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline", "AnimateDiffControlNetPipeline", "AnimateDiffPAGPipeline", "AnimateDiffPipeline", "AnimateDiffSDXLPipeline", "AnimateDiffSparseControlNetPipeline", "AnimateDiffVideoToVideoPipeline", "AudioLDM2Pipeline", "AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel", "AudioLDMPipeline", "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", "CLIPImageProjection", "CogVideoXPipeline", "CycleDiffusionPipeline", "FluxPipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", "IFInpaintingPipeline", 
"IFInpaintingSuperResolutionPipeline", "IFPipeline", "IFSuperResolutionPipeline", "ImageTextPipelineOutput", "Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline", "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", "KandinskyImg2ImgPipeline", "KandinskyInpaintCombinedPipeline", "KandinskyInpaintPipeline", "KandinskyPipeline", "KandinskyPriorPipeline", "KandinskyV22CombinedPipeline", "KandinskyV22ControlnetImg2ImgPipeline", "KandinskyV22ControlnetPipeline", "KandinskyV22Img2ImgCombinedPipeline", "KandinskyV22Img2ImgPipeline", "KandinskyV22InpaintCombinedPipeline", "KandinskyV22InpaintPipeline", "KandinskyV22Pipeline", "KandinskyV22PriorEmb2EmbPipeline", "KandinskyV22PriorPipeline", "LatentConsistencyModelImg2ImgPipeline", "LatentConsistencyModelPipeline", "LattePipeline", "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", "LuminaText2ImgPipeline", "MarigoldDepthPipeline", "MarigoldNormalsPipeline", "MusicLDMPipeline", "PaintByExamplePipeline", "PIAPipeline", "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", "PixArtSigmaPipeline", "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", "StableAudioPipeline", "StableAudioProjectionModel", "StableCascadeCombinedPipeline", "StableCascadeDecoderPipeline", "StableCascadePriorPipeline", "StableDiffusion3ControlNetPipeline", "StableDiffusion3Img2ImgPipeline", "StableDiffusion3InpaintPipeline", "StableDiffusion3PAGPipeline", "StableDiffusion3Pipeline", "StableDiffusionAdapterPipeline", "StableDiffusionAttendAndExcitePipeline", "StableDiffusionControlNetImg2ImgPipeline", "StableDiffusionControlNetInpaintPipeline", "StableDiffusionControlNetPAGPipeline", "StableDiffusionControlNetPipeline", "StableDiffusionControlNetXSPipeline", "StableDiffusionDepth2ImgPipeline", "StableDiffusionDiffEditPipeline", "StableDiffusionGLIGENPipeline", "StableDiffusionGLIGENTextImagePipeline", "StableDiffusionImageVariationPipeline", "StableDiffusionImg2ImgPipeline", 
"StableDiffusionInpaintPipeline", "StableDiffusionInpaintPipelineLegacy", "StableDiffusionInstructPix2PixPipeline", "StableDiffusionLatentUpscalePipeline", "StableDiffusionLDM3DPipeline", "StableDiffusionModelEditingPipeline", "StableDiffusionPAGPipeline", "StableDiffusionPanoramaPipeline", "StableDiffusionParadigmsPipeline", "StableDiffusionPipeline", "StableDiffusionPipelineSafe", "StableDiffusionPix2PixZeroPipeline", "StableDiffusionSAGPipeline", "StableDiffusionUpscalePipeline", "StableDiffusionXLAdapterPipeline", "StableDiffusionXLControlNetImg2ImgPipeline", "StableDiffusionXLControlNetInpaintPipeline", "StableDiffusionXLControlNetPAGPipeline", "StableDiffusionXLControlNetPipeline", "StableDiffusionXLControlNetXSPipeline", "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", "StableDiffusionXLPAGImg2ImgPipeline", "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLPAGPipeline", "StableDiffusionXLPipeline", "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", "StableVideoDiffusionPipeline", "TextToVideoSDPipeline", "TextToVideoZeroPipeline", "TextToVideoZeroSDXLPipeline", "UnCLIPImageVariationPipeline", "UnCLIPPipeline", "UniDiffuserModel", "UniDiffuserPipeline", "UniDiffuserTextDecoder", "VersatileDiffusionDualGuidedPipeline", "VersatileDiffusionImageVariationPipeline", "VersatileDiffusionPipeline", "VersatileDiffusionTextToImagePipeline", "VideoToVideoSDPipeline", "VQDiffusionPipeline", "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", ] ) try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if 
not name.startswith("_") ] else: _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"]) try: if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403 _import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [ name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_") ] else: _import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"]) try: if not (is_torch_available() and is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") ] else: _import_structure["pipelines"].extend( [ "OnnxStableDiffusionImg2ImgPipeline", "OnnxStableDiffusionInpaintPipeline", "OnnxStableDiffusionInpaintPipelineLegacy", "OnnxStableDiffusionPipeline", "OnnxStableDiffusionUpscalePipeline", "StableDiffusionOnnxPipeline", ] ) try: if not (is_torch_available() and is_librosa_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_torch_and_librosa_objects # noqa F403 _import_structure["utils.dummy_torch_and_librosa_objects"] = [ name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") ] else: _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) try: if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): raise OptionalDependencyNotAvailable() except 
OptionalDependencyNotAvailable: from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") ] else: _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"]) try: if not is_flax_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_flax_objects # noqa F403 _import_structure["utils.dummy_flax_objects"] = [ name for name in dir(dummy_flax_objects) if not name.startswith("_") ] else: _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) _import_structure["schedulers"].extend( [ "FlaxDDIMScheduler", "FlaxDDPMScheduler", "FlaxDPMSolverMultistepScheduler", "FlaxEulerDiscreteScheduler", "FlaxKarrasVeScheduler", "FlaxLMSDiscreteScheduler", "FlaxPNDMScheduler", "FlaxSchedulerMixin", "FlaxScoreSdeVeScheduler", ] ) try: if not (is_flax_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_flax_and_transformers_objects # noqa F403 _import_structure["utils.dummy_flax_and_transformers_objects"] = [ name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_") ] else: _import_structure["pipelines"].extend( [ "FlaxStableDiffusionControlNetPipeline", "FlaxStableDiffusionImg2ImgPipeline", "FlaxStableDiffusionInpaintPipeline", "FlaxStableDiffusionPipeline", "FlaxStableDiffusionXLPipeline", ] ) try: if not (is_note_seq_available()): raise OptionalDependencyNotAvailable() except 
OptionalDependencyNotAvailable: from .utils import dummy_note_seq_objects # noqa F403 _import_structure["utils.dummy_note_seq_objects"] = [ name for name in dir(dummy_note_seq_objects) if not name.startswith("_") ] else: _import_structure["pipelines"].extend(["MidiProcessor"]) if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .configuration_utils import ConfigMixin try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_onnx_objects import * # noqa F403 else: from .pipelines import OnnxRuntimeModel try: if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: from .models import ( AsymmetricAutoencoderKL, AuraFlowTransformer2DModel, AutoencoderKL, AutoencoderKLCogVideoX, AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, CogVideoXTransformer3DModel, ConsistencyDecoderVAE, ControlNetModel, ControlNetXSAdapter, DiTTransformer2DModel, FluxTransformer2DModel, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel, I2VGenXLUNet, Kandinsky3UNet, LatteTransformer3DModel, LuminaNextDiT2DModel, ModelMixin, MotionAdapter, MultiAdapter, PixArtTransformer2DModel, PriorTransformer, SD3ControlNetModel, SD3MultiControlNetModel, SD3Transformer2DModel, SparseControlNetModel, StableAudioDiTModel, T2IAdapter, T5FilmDecoder, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, UNet3DConditionModel, UNetControlNetXSModel, UNetMotionModel, UNetSpatioTemporalConditionModel, UVit2DModel, VQModel, ) from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, get_scheduler, ) from .pipelines import ( AudioPipelineOutput, AutoPipelineForImage2Image, AutoPipelineForInpainting, 
AutoPipelineForText2Image, BlipDiffusionControlNetPipeline, BlipDiffusionPipeline, CLIPImageProjection, ConsistencyModelPipeline, DanceDiffusionPipeline, DDIMPipeline, DDPMPipeline, DiffusionPipeline, DiTPipeline, ImagePipelineOutput, KarrasVePipeline, LDMPipeline, LDMSuperResolutionPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, StableDiffusionMixin, ) from .schedulers import ( AmusedScheduler, CMStochasticIterativeScheduler, CogVideoXDDIMScheduler, CogVideoXDPMScheduler, DDIMInverseScheduler, DDIMParallelScheduler, DDIMScheduler, DDPMParallelScheduler, DDPMScheduler, DDPMWuerstchenScheduler, DEISMultistepScheduler, DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EDMDPMSolverMultistepScheduler, EDMEulerScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler, HeunDiscreteScheduler, IPNDMScheduler, KarrasVeScheduler, KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, LCMScheduler, PNDMScheduler, RePaintScheduler, SASolverScheduler, SchedulerMixin, ScoreSdeVeScheduler, TCDScheduler, UnCLIPScheduler, UniPCMultistepScheduler, VQDiffusionScheduler, ) from .training_utils import EMAModel try: if not (is_torch_available() and is_scipy_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_scipy_objects import * # noqa F403 else: from .schedulers import LMSDiscreteScheduler try: if not (is_torch_available() and is_torchsde_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_torchsde_objects import * # noqa F403 else: from .schedulers import CosineDPMSolverMultistepScheduler, DPMSolverSDEScheduler try: if not (is_torch_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_objects import * # noqa 
F403 else: from .pipelines import ( AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline, AnimateDiffControlNetPipeline, AnimateDiffPAGPipeline, AnimateDiffPipeline, AnimateDiffSDXLPipeline, AnimateDiffSparseControlNetPipeline, AnimateDiffVideoToVideoPipeline, AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, AudioLDMPipeline, AuraFlowPipeline, CLIPImageProjection, CogVideoXPipeline, CycleDiffusionPipeline, FluxPipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, HunyuanDiTPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, IFPipeline, IFSuperResolutionPipeline, ImageTextPipelineOutput, Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyImg2ImgPipeline, KandinskyInpaintCombinedPipeline, KandinskyInpaintPipeline, KandinskyPipeline, KandinskyPriorPipeline, KandinskyV22CombinedPipeline, KandinskyV22ControlnetImg2ImgPipeline, KandinskyV22ControlnetPipeline, KandinskyV22Img2ImgCombinedPipeline, KandinskyV22Img2ImgPipeline, KandinskyV22InpaintCombinedPipeline, KandinskyV22InpaintPipeline, KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline, KandinskyV22PriorPipeline, LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, LattePipeline, LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, LuminaText2ImgPipeline, MarigoldDepthPipeline, MarigoldNormalsPipeline, MusicLDMPipeline, PaintByExamplePipeline, PIAPipeline, PixArtAlphaPipeline, PixArtSigmaPAGPipeline, PixArtSigmaPipeline, SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, StableAudioPipeline, StableAudioProjectionModel, StableCascadeCombinedPipeline, StableCascadeDecoderPipeline, StableCascadePriorPipeline, StableDiffusion3ControlNetPipeline, StableDiffusion3Img2ImgPipeline, 
StableDiffusion3InpaintPipeline, StableDiffusion3PAGPipeline, StableDiffusion3Pipeline, StableDiffusionAdapterPipeline, StableDiffusionAttendAndExcitePipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPAGPipeline, StableDiffusionControlNetPipeline, StableDiffusionControlNetXSPipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy, StableDiffusionInstructPix2PixPipeline, StableDiffusionLatentUpscalePipeline, StableDiffusionLDM3DPipeline, StableDiffusionModelEditingPipeline, StableDiffusionPAGPipeline, StableDiffusionPanoramaPipeline, StableDiffusionParadigmsPipeline, StableDiffusionPipeline, StableDiffusionPipelineSafe, StableDiffusionPix2PixZeroPipeline, StableDiffusionSAGPipeline, StableDiffusionUpscalePipeline, StableDiffusionXLAdapterPipeline, StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPAGPipeline, StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetXSPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, StableDiffusionXLPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, StableVideoDiffusionPipeline, TextToVideoSDPipeline, TextToVideoZeroPipeline, TextToVideoZeroSDXLPipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder, VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, VideoToVideoSDPipeline, VQDiffusionPipeline, WuerstchenCombinedPipeline, 
WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 else: from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline try: if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403 else: from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline try: if not (is_torch_available() and is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 else: from .pipelines import ( OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionInpaintPipelineLegacy, OnnxStableDiffusionPipeline, OnnxStableDiffusionUpscalePipeline, StableDiffusionOnnxPipeline, ) try: if not (is_torch_available() and is_librosa_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_librosa_objects import * # noqa F403 else: from .pipelines import AudioDiffusionPipeline, Mel try: if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 else: from .pipelines import SpectrogramDiffusionPipeline try: if not is_flax_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_flax_objects import * # noqa 
class PipelineCallback(ConfigMixin):
    """
    Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
    custom callbacks and ensures that all callbacks have a consistent interface.

    Please implement the following:
        `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able
        to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
        `callback_fn`: This method defines the core functionality of your callback.

    Exactly one of `cutoff_step_ratio` and `cutoff_step_index` must be set. Because `cutoff_step_ratio` defaults to
    `1.0`, pass `cutoff_step_ratio=None` explicitly when you want to use `cutoff_step_index`.

    Raises:
        ValueError: If both or neither of the two cutoff arguments are provided, or if `cutoff_step_ratio` is not a
            float in the closed interval `[0.0, 1.0]`.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
        super().__init__()

        # Both missing or both present is an error: the two options are mutually exclusive.
        if (cutoff_step_ratio is None) == (cutoff_step_index is None):
            raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")

        if cutoff_step_ratio is not None and (
            not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0)
        ):
            raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")

    @property
    def tensor_inputs(self) -> List[str]:
        raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")

    def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
        raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")

    def _cutoff_step(self, pipeline):
        """
        Return the step index at which the callback takes effect.

        `cutoff_step_index` is used when it is not None; otherwise the step is derived from `cutoff_step_ratio` as
        `int(pipeline.num_timesteps * cutoff_step_ratio)`.
        """
        cutoff_step_index = self.config.cutoff_step_index
        if cutoff_step_index is not None:
            return cutoff_step_index
        return int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)


class MultiPipelineCallbacks:
    """
    This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and
    provides a unified interface for calling all of them.
    """

    def __init__(self, callbacks: List[PipelineCallback]):
        self.callbacks = callbacks

    @property
    def tensor_inputs(self) -> List[str]:
        """Concatenation, in callback order, of every wrapped callback's `tensor_inputs` (duplicates are kept)."""
        return [name for callback in self.callbacks for name in callback.tensor_inputs]

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """
        Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
        """
        # Each callback receives the kwargs returned by the previous one.
        for callback in self.callbacks:
            callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs)

        return callback_kwargs


class SDCFGCutoffCallback(PipelineCallback):
    """
    Callback function for Stable Diffusion Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
    `cutoff_step_index`), this callback will disable the CFG.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
    """

    tensor_inputs = ["prompt_embeds"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        if step_index == self._cutoff_step(pipeline):
            key = self.tensor_inputs[0]
            # "-1" denotes the embeddings for conditional text tokens.
            conditional_embeds = callback_kwargs[key][-1:]

            pipeline._guidance_scale = 0.0

            callback_kwargs[key] = conditional_embeds
        return callback_kwargs


class SDXLCFGCutoffCallback(PipelineCallback):
    """
    Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
    `cutoff_step_index`), this callback will disable the CFG.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
    """

    tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        if step_index == self._cutoff_step(pipeline):
            # Only the first three inputs are touched: text embeddings, pooled text embeddings, added time vector.
            keys = self.tensor_inputs[:3]
            # "-1" denotes the conditional entry of each tensor. All three are read before anything is mutated.
            conditional = [callback_kwargs[key][-1:] for key in keys]

            pipeline._guidance_scale = 0.0

            for key, value in zip(keys, conditional):
                callback_kwargs[key] = value
        return callback_kwargs


class IPAdapterScaleCutoffCallback(PipelineCallback):
    """
    Callback function for any pipeline that inherits `IPAdapterMixin`. After certain number of steps (set by
    `cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`.

    Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step.
    """

    tensor_inputs = []

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        if step_index == self._cutoff_step(pipeline):
            pipeline.set_ip_adapter_scale(0.0)
        return callback_kwargs
def main():
    """
    Entry point of the `diffusers-cli` tool.

    Registers the `env` and `fp16_safetensors` sub-commands, parses the command line and runs the command object
    produced by the selected sub-command's factory (stored as `func` on the parsed namespace).

    When no sub-command was selected (the namespace has no `func`), the help text is printed and the process exits
    with status 1.
    """
    # Local import: `sys.exit` is always available, whereas the bare `exit` builtin is injected by the `site`
    # module and is missing under `python -S` or in frozen applications.
    import sys

    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")

    # Register commands
    EnvironmentCommand.register_subcommand(commands_parser)
    FP16SafetensorsCommand.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        sys.exit(1)

    # Run
    service = args.func(args)
    service.run()
def info_command_factory(_):
    """Factory registered for the `env` sub-command; the parsed CLI arguments are ignored."""
    return EnvironmentCommand()


class EnvironmentCommand(BaseDiffusersCLICommand):
    """`diffusers-cli env`: collect and print version/platform information for bug reports."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser) -> None:
        """Add the `env` sub-parser and bind it to `info_command_factory`."""
        download_parser = parser.add_parser("env")
        download_parser.set_defaults(func=info_command_factory)

    @staticmethod
    def _command_output(command):
        """Run `command` and return its decoded stdout, or `None` when the executable cannot be found."""
        try:
            sp = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except FileNotFoundError:
            return None
        out, _ = sp.communicate()
        # This is a best-effort diagnostic: undecodable bytes must not abort the whole report.
        return out.decode("utf-8", errors="replace")

    @staticmethod
    def _field_value(text, label):
        """Return the stripped remainder of the line following `label` in `text`, or `None` if `label` is absent."""
        start = text.find(label)
        if start == -1:
            return None
        # Splitting (instead of slicing up to `find("\n")`) keeps the last character when there is no newline.
        return text[start + len(label) :].split("\n", 1)[0].strip()

    @classmethod
    def _detect_accelerator(cls) -> str:
        """
        Describe the accelerator by querying `nvidia-smi` (Linux/Windows) or `system_profiler` (macOS).

        Returns "NA" when the tool is missing or reports nothing. On any other OS a hint is printed and "NA" is
        returned.
        """
        system = platform.system()

        if system in {"Linux", "Windows"}:
            out_str = cls._command_output(
                ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"]
            )
            return out_str.strip() if out_str else "NA"

        if system == "Darwin":  # Mac OS
            out_str = cls._command_output(["system_profiler", "SPDisplaysDataType"])
            if out_str is None:
                return "NA"

            chipset = cls._field_value(out_str, "Chipset Model:")
            vram = cls._field_value(out_str, "VRAM (Total):")

            # Build from the fields actually found so a missing chipset no longer yields "NA VRAM: ...".
            parts = []
            if chipset is not None:
                parts.append(chipset)
            if vram is not None:
                parts.append(f"VRAM: {vram}")
            return " ".join(parts) if parts else "NA"

        print("It seems you are running an unusual OS. Could you fill in the accelerator manually?")
        return "NA"

    def run(self) -> dict:
        """
        Gather versions of Diffusers and its optional dependencies plus platform details, print them in a
        copy-pastable form and return them.

        Returns:
            `dict`: Mapping from a human-readable label to the collected value. Libraries that are not available
            are reported as "not installed".
        """
        hub_version = huggingface_hub.__version__

        safetensors_version = "not installed"
        if is_safetensors_available():
            import safetensors

            safetensors_version = safetensors.__version__

        pt_version = "not installed"
        pt_cuda_available = "NA"
        if is_torch_available():
            import torch

            pt_version = torch.__version__
            pt_cuda_available = torch.cuda.is_available()

        flax_version = "not installed"
        jax_version = "not installed"
        jaxlib_version = "not installed"
        jax_backend = "NA"
        if is_flax_available():
            import flax
            import jax
            import jaxlib

            flax_version = flax.__version__
            jax_version = jax.__version__
            jaxlib_version = jaxlib.__version__
            jax_backend = jax.lib.xla_bridge.get_backend().platform

        transformers_version = "not installed"
        if is_transformers_available():
            import transformers

            transformers_version = transformers.__version__

        accelerate_version = "not installed"
        if is_accelerate_available():
            import accelerate

            accelerate_version = accelerate.__version__

        peft_version = "not installed"
        if is_peft_available():
            import peft

            peft_version = peft.__version__

        bitsandbytes_version = "not installed"
        if is_bitsandbytes_available():
            import bitsandbytes

            bitsandbytes_version = bitsandbytes.__version__

        xformers_version = "not installed"
        if is_xformers_available():
            import xformers

            xformers_version = xformers.__version__

        platform_info = platform.platform()

        is_google_colab_str = "Yes" if is_google_colab() else "No"

        accelerator = self._detect_accelerator()

        info = {
            "🤗 Diffusers version": version,
            "Platform": platform_info,
            "Running on Google Colab?": is_google_colab_str,
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
            "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
            "Jax version": jax_version,
            "JaxLib version": jaxlib_version,
            "Huggingface_hub version": hub_version,
            "Transformers version": transformers_version,
            "Accelerate version": accelerate_version,
            "PEFT version": peft_version,
            "Bitsandbytes version": bitsandbytes_version,
            "Safetensors version": safetensors_version,
            "xFormers version": xformers_version,
            "Accelerator": accelerator,
            # Left blank on purpose: the user fills these in when pasting into an issue.
            "Using GPU in script?": "",
            "Using distributed or parallel set-up in script?": "",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def format_dict(d: dict) -> str:
        """Render `d` as a Markdown bullet list, one `- key: value` line per entry, ending with a newline."""
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
""" Usage example: diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors """ import glob import json import warnings from argparse import ArgumentParser, Namespace from importlib import import_module import huggingface_hub import torch from huggingface_hub import hf_hub_download from packaging import version from ..utils import logging from . import BaseDiffusersCLICommand def conversion_command_factory(args: Namespace): if args.use_auth_token: warnings.warn( "The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now" " handled automatically if user is logged in." ) return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors) class FP16SafetensorsCommand(BaseDiffusersCLICommand): @staticmethod def register_subcommand(parser: ArgumentParser): conversion_parser = parser.add_parser("fp16_safetensors") conversion_parser.add_argument( "--ckpt_id", type=str, help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.", ) conversion_parser.add_argument( "--fp16", action="store_true", help="If serializing the variables in FP16 precision." ) conversion_parser.add_argument( "--use_safetensors", action="store_true", help="If serializing in the safetensors format." ) conversion_parser.add_argument( "--use_auth_token", action="store_true", help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.", ) conversion_parser.set_defaults(func=conversion_command_factory) def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool): self.logger = logging.get_logger("diffusers-cli/fp16_safetensors") self.ckpt_id = ckpt_id self.local_ckpt_dir = f"/tmp/{ckpt_id}" self.fp16 = fp16 self.use_safetensors = use_safetensors if not self.use_safetensors and not self.fp16: raise NotImplementedError( "When `use_safetensors` and `fp16` both are False, then this command is of no use." 
class FP16SafetensorsCommand(BaseDiffusersCLICommand):
    """
    CLI command that re-serializes a Hub pipeline in FP16 and/or safetensors format and opens a pull request on the
    source repository with the converted files.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the `fp16_safetensors` sub-command and its flags on `parser`."""
        conversion_parser = parser.add_parser("fp16_safetensors")
        conversion_parser.add_argument(
            "--ckpt_id",
            type=str,
            help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
        )
        conversion_parser.add_argument(
            "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
        )
        conversion_parser.add_argument(
            "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
        )
        conversion_parser.add_argument(
            "--use_auth_token",
            action="store_true",
            help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
        )
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool):
        """
        Args:
            ckpt_id (`str`): Repo id of the checkpoint to convert.
            fp16 (`bool`): Whether to serialize the weights as the `fp16` variant.
            use_safetensors (`bool`): Whether to serialize with safetensors.

        Raises:
            `NotImplementedError`: If both `fp16` and `use_safetensors` are false.
        """
        self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
        self.ckpt_id = ckpt_id
        self.local_ckpt_dir = f"/tmp/{ckpt_id}"
        self.fp16 = fp16
        self.use_safetensors = use_safetensors

        if not self.use_safetensors and not self.fp16:
            raise NotImplementedError(
                "When `use_safetensors` and `fp16` both are False, then this command is of no use."
            )

    def run(self):
        """
        Download the pipeline, save it locally in the requested format, and open a Hub PR with the converted files.

        Raises:
            `ImportError`: If the installed `huggingface_hub` is older than 0.9.0.
        """
        # Function-scope import: `os` is only needed to compute repo-relative paths below.
        import os

        if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
            raise ImportError(
                "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
                " installation."
            )
        else:
            from huggingface_hub import create_commit
            from huggingface_hub._commit_api import CommitOperationAdd

        model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json")
        with open(model_index, "r") as f:
            pipeline_class_name = json.load(f)["_class_name"]
        pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
        self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")

        # Load the appropriate pipeline. We could have use `DiffusionPipeline`
        # here, but just to avoid any rough edge cases.
        pipeline = pipeline_class.from_pretrained(
            self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
        )
        pipeline.save_pretrained(
            self.local_ckpt_dir,
            safe_serialization=True if self.use_safetensors else False,
            variant="fp16" if self.fp16 else None,
        )
        self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")

        # Fetch all the paths. `__init__` guarantees at least one of the two flags is set.
        if self.fp16:
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
        elif self.use_safetensors:
            modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")

        # Prepare for the PR.
        commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
        operations = []
        for path in modified_paths:
            # Derive the in-repo path relative to the local checkpoint directory instead of slicing a fixed number
            # of path segments: the fixed slice only worked for two-segment ids such as `org/name` and dropped the
            # component sub-folder for single-segment ids.
            path_in_repo = os.path.relpath(path, self.local_ckpt_dir).replace(os.sep, "/")
            operations.append(CommitOperationAdd(path_in_repo=path_in_repo, path_or_fileobj=path))

        # Open the PR.
        commit_description = (
            "Variables converted by the [`diffusers`' `fp16_safetensors`"
            " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
        )
        hub_pr_url = create_commit(
            repo_id=self.ckpt_id,
            operations=operations,
            commit_message=commit_message,
            commit_description=commit_description,
            repo_type="model",
            create_pr=True,
        ).pr_url
        self.logger.info(f"PR created here: {hub_pr_url}.")
class FrozenDict(OrderedDict):
    """
    Ordered mapping that mirrors every item as an instance attribute and refuses removal-style mutation.

    `__delitem__`, `setdefault`, `pop` and `update` always raise. `__setattr__` and `__setitem__` only raise when
    their `"__frozen"` guard evaluates to true.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Mirror each entry as an attribute so `cfg.key` works alongside `cfg["key"]`.
        for attr_name, attr_value in self.items():
            setattr(self, attr_name, attr_value)

        self.__frozen = True

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __setattr__(self, name, value):
        # NOTE(review): the string given to `hasattr` is not private-name-mangled while `self.__frozen` is, so this
        # guard presumably never becomes true -- confirm before relying on attribute writes being rejected.
        is_locked = hasattr(self, "__frozen") and self.__frozen
        if not is_locked:
            super().__setattr__(name, value)
            return
        raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")

    def __setitem__(self, name, value):
        # Same guard as `__setattr__`; see the review note there.
        is_locked = hasattr(self, "__frozen") and self.__frozen
        if not is_locked:
            super().__setitem__(name, value)
            return
        raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
class ConfigMixin:
    r"""
    Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
    provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
    saving classes that inherit from [`ConfigMixin`].

    Class attributes:
        - **config_name** (`str`) -- A filename under which the config should be stored when calling
          [`~ConfigMixin.save_config`] (should be overridden by parent class).
        - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
          overridden by subclass).
        - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
        - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
          should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
          subclass).
    """

    config_name = None
    ignore_for_config = []
    has_compatibles = False

    _deprecated_kwargs = []

    def register_to_config(self, **kwargs):
        """
        Merge `kwargs` into the stored configuration and replace `self._internal_dict` with a new `FrozenDict`.

        Raises:
            `NotImplementedError`: If the class does not define `config_name`.
        """
        if self.config_name is None:
            raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
        # Special case for `kwargs` used in deprecation warning added to schedulers
        # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
        # or solve in a more general way.
        kwargs.pop("kwargs", None)

        if not hasattr(self, "_internal_dict"):
            internal_dict = kwargs
        else:
            previous_dict = dict(self._internal_dict)
            internal_dict = {**self._internal_dict, **kwargs}
            logger.debug(f"Updating config from {previous_dict} to {internal_dict}")

        self._internal_dict = FrozenDict(internal_dict)

    def __getattr__(self, name: str) -> Any:
        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129

        This function is mostly copied from PyTorch's __getattr__ overwrite:
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
        """

        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
        is_attribute = name in self.__dict__

        if is_in_config and not is_attribute:
            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
            return self._internal_dict[name]

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
        [`~ConfigMixin.from_config`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file is saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

        Raises:
            `AssertionError`: If `save_directory` points to an existing file.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        # If we save using the predefined names, we can load using `from_config`
        output_config_file = os.path.join(save_directory, self.config_name)

        self.to_json_file(output_config_file)
        logger.info(f"Configuration saved in {output_config_file}")

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            private = kwargs.pop("private", False)
            create_pr = kwargs.pop("create_pr", False)
            token = kwargs.pop("token", None)
            # `save_directory` may be an `os.PathLike`, which has no `.split`; go through `str` first.
            repo_id = kwargs.pop("repo_id", str(save_directory).split(os.path.sep)[-1])
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )

    @classmethod
    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
        r"""
        Instantiate a Python class from a config dictionary.

        Parameters:
            config (`Dict[str, Any]`):
                A config dictionary from which the Python class is instantiated. Make sure to only load configuration
                files of compatible classes.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it is loaded) and initiate the Python class.
                `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
                overwrite the same named arguments in `config`.

        Returns:
            [`ModelMixin`] or [`SchedulerMixin`]:
                A model or scheduler object instantiated from a config dictionary.

        Examples:

        ```python
        >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler

        >>> # Download scheduler from huggingface.co and cache.
        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")

        >>> # Instantiate DDIM scheduler class with same config as DDPM
        >>> scheduler = DDIMScheduler.from_config(scheduler.config)

        >>> # Instantiate PNDM scheduler class with same config as DDPM
        >>> scheduler = PNDMScheduler.from_config(scheduler.config)
        ```
        """
        # <===== TO BE REMOVED WITH DEPRECATION
        # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
        if "pretrained_model_name_or_path" in kwargs:
            config = kwargs.pop("pretrained_model_name_or_path")

        if config is None:
            raise ValueError("Please make sure to provide a config as the first positional argument.")
        # ======>

        if not isinstance(config, dict):
            deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
            # The follow-up sentences start with a space so they do not run into the sentence above.
            if "Scheduler" in cls.__name__:
                deprecation_message += (
                    f" If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
                    " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
                    " be removed in v1.0.0."
                )
            elif "Model" in cls.__name__:
                deprecation_message += (
                    f" If you were trying to load a model, please use {cls}.load_config(...) followed by"
                    f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
                    " instead. This functionality will be removed in v1.0.0."
                )
            deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
            config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)

        init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)

        # Allow dtype to be specified on initialization
        if "dtype" in unused_kwargs:
            init_dict["dtype"] = unused_kwargs.pop("dtype")

        # add possible deprecated kwargs
        for deprecated_kwarg in cls._deprecated_kwargs:
            if deprecated_kwarg in unused_kwargs:
                init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)

        # Return model and optionally state and/or unused_kwargs
        model = cls(**init_dict)

        # make sure to also save config parameters that might be used for compatible classes
        # update _class_name
        if "_class_name" in hidden_dict:
            hidden_dict["_class_name"] = cls.__name__

        model.register_to_config(**hidden_dict)

        # add hidden kwargs of compatible classes to unused_kwargs
        unused_kwargs = {**unused_kwargs, **hidden_dict}

        if return_unused_kwargs:
            return (model, unused_kwargs)
        else:
            return model

    @classmethod
    def get_config_dict(cls, *args, **kwargs):
        """Deprecated alias of [`~ConfigMixin.load_config`]; emits a deprecation notice and forwards all arguments."""
        deprecation_message = (
            f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
            " removed in version v1.0.0"
        )
        deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
        return cls.load_config(*args, **kwargs)

    @classmethod
    @validate_hf_hub_args
    def load_config(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        return_unused_kwargs=False,
        return_commit_hash=False,
        **kwargs,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        r"""
        Load a model or scheduler configuration.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
                      [`~ConfigMixin.save_config`].

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether unused keyword arguments of the config are returned.
            return_commit_hash (`bool`, *optional*, defaults to `False`):
                Whether the `commit_hash` of the loaded configuration are returned.

        Returns:
            `dict`:
                A dictionary of all the parameters stored in a JSON configuration file. When `return_unused_kwargs`
                and/or `return_commit_hash` is set, a tuple starting with that dictionary and followed by the
                requested extras is returned instead.

        """
        cache_dir = kwargs.pop("cache_dir", None)
        local_dir = kwargs.pop("local_dir", None)
        local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        _ = kwargs.pop("mirror", None)
        subfolder = kwargs.pop("subfolder", None)
        user_agent = kwargs.pop("user_agent", {})

        user_agent = {**user_agent, "file_type": "config"}
        user_agent = http_user_agent(user_agent)

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)

        if cls.config_name is None:
            raise ValueError(
                "`self.config_name` is not defined. Note that one should not load a config from "
                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
            )

        if os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        elif os.path.isdir(pretrained_model_name_or_path):
            if subfolder is not None and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            ):
                config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
                # Load from a PyTorch checkpoint
                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            else:
                raise EnvironmentError(
                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
                )
        else:
            # Each handler chains the original exception so the root cause stays visible in the traceback.
            try:
                # Load from URL or cache if already cached
                config_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=cls.config_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    subfolder=subfolder,
                    revision=revision,
                    local_dir=local_dir,
                    local_dir_use_symlinks=local_dir_use_symlinks,
                )
            except RepositoryNotFoundError as err:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
                    " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
                    " token having permission to this repo with `token` or log in with `huggingface-cli login`."
                ) from err
            except RevisionNotFoundError as err:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
                    " this model name. Check the model page at"
                    f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                ) from err
            except EntryNotFoundError as err:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
                ) from err
            except HTTPError as err:
                raise EnvironmentError(
                    "There was a specific connection error when trying to load"
                    f" {pretrained_model_name_or_path}:\n{err}"
                ) from err
            except ValueError as err:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                    f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
                    " run the library in offline mode at"
                    " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
                ) from err
            except EnvironmentError as err:
                raise EnvironmentError(
                    f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a {cls.config_name} file"
                ) from err

        try:
            # Load config dict
            config_dict = cls._dict_from_json_file(config_file)

            commit_hash = extract_commit_hash(config_file)
        except (json.JSONDecodeError, UnicodeDecodeError) as err:
            raise EnvironmentError(
                f"It looks like the config file at '{config_file}' is not a valid JSON file."
            ) from err

        if not (return_unused_kwargs or return_commit_hash):
            return config_dict

        outputs = (config_dict,)

        if return_unused_kwargs:
            outputs += (kwargs,)

        if return_commit_hash:
            outputs += (commit_hash,)

        return outputs

    @staticmethod
    def _get_init_keys(input_class):
        """Return the set of parameter names (including `self`) of `input_class.__init__`."""
        return set(dict(inspect.signature(input_class.__init__).parameters).keys())

    @classmethod
    def extract_init_dict(cls, config_dict, **kwargs):
        """
        Split a loaded config into the arguments accepted by `cls.__init__` and everything else.

        Returns:
            `tuple`: `(init_dict, unused_kwargs, hidden_config_dict)` where `init_dict` holds the values passed to
            `__init__`, `unused_kwargs` holds leftover config entries and keyword arguments, and `hidden_config_dict`
            holds the original config entries that did not end up in `init_dict`.
        """
        # Skip keys that were not present in the original config, so default __init__ values were used
        used_defaults = config_dict.get("_use_default_values", [])
        config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

        # 0. Copy origin config dict
        original_dict = dict(config_dict.items())

        # 1. Retrieve expected config attributes from __init__ signature
        expected_keys = cls._get_init_keys(cls)
        expected_keys.remove("self")
        # remove general kwargs if present in dict
        if "kwargs" in expected_keys:
            expected_keys.remove("kwargs")
        # remove flax internal keys
        if hasattr(cls, "_flax_internal_args"):
            for arg in cls._flax_internal_args:
                expected_keys.remove(arg)

        # 2. Remove attributes that cannot be expected from expected config attributes
        # remove keys to be ignored
        if len(cls.ignore_for_config) > 0:
            expected_keys = expected_keys - set(cls.ignore_for_config)

        # load diffusers library to import compatible and original scheduler
        diffusers_library = importlib.import_module(__name__.split(".")[0])

        if cls.has_compatibles:
            compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
        else:
            compatible_classes = []

        expected_keys_comp_cls = set()
        for c in compatible_classes:
            expected_keys_c = cls._get_init_keys(c)
            expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
        expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
        config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}

        # remove attributes from orig class that cannot be expected
        orig_cls_name = config_dict.pop("_class_name", cls.__name__)
        if (
            isinstance(orig_cls_name, str)
            and orig_cls_name != cls.__name__
            and hasattr(diffusers_library, orig_cls_name)
        ):
            orig_cls = getattr(diffusers_library, orig_cls_name)
            unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
            config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
        elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
            raise ValueError(
                "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
            )

        # remove private attributes
        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
        init_dict = {}
        for key in expected_keys:
            # if config param is passed to kwarg and is present in config dict
            # it should overwrite existing config dict key
            if key in kwargs and key in config_dict:
                config_dict[key] = kwargs.pop(key)

            if key in kwargs:
                # overwrite key
                init_dict[key] = kwargs.pop(key)
            elif key in config_dict:
                # use value from config dict
                init_dict[key] = config_dict.pop(key)

        # 4. Give nice warning if unexpected values have been passed
        if len(config_dict) > 0:
            logger.warning(
                f"The config attributes {config_dict} were passed to {cls.__name__}, "
                "but are not expected and will be ignored. Please verify your "
                f"{cls.config_name} configuration file."
            )

        # 5. Give nice info if config attributes are initialized to default because they have not been passed
        passed_keys = set(init_dict.keys())
        if len(expected_keys - passed_keys) > 0:
            logger.info(
                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
            )

        # 6. Define unused keyword arguments
        unused_kwargs = {**config_dict, **kwargs}

        # 7. Define "hidden" config parameters that were saved for compatible classes
        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}

        return init_dict, unused_kwargs, hidden_config_dict

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        """Read `json_file` as UTF-8 text and return the parsed JSON content."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    @property
    def config(self) -> Dict[str, Any]:
        """
        Returns the config of the class as a frozen dictionary

        Returns:
            `Dict[str, Any]`: Config of the class.
        """
        return self._internal_dict

    def to_json_string(self) -> str:
        """
        Serializes the configuration instance to a JSON string.

        Returns:
            `str`:
                String containing all the attributes that make up the configuration instance in JSON format.
        """
        # NOTE: when `_internal_dict` exists, the two assignments below write `_class_name` and `_diffusers_version`
        # into that stored dict itself, not into a copy.
        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
        config_dict["_class_name"] = self.__class__.__name__
        config_dict["_diffusers_version"] = __version__

        def to_json_saveable(value):
            if isinstance(value, np.ndarray):
                value = value.tolist()
            elif isinstance(value, Path):
                value = value.as_posix()
            return value

        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
        # Don't save "_ignore_files" or "_use_default_values"
        config_dict.pop("_ignore_files", None)
        config_dict.pop("_use_default_values", None)

        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save the configuration instance's parameters to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file to save a configuration instance's parameters.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())
def register_to_config(init):
    r"""
    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!

    Raises:
        `RuntimeError`: At call time, if the decorated `__init__` belongs to a class that is not a `ConfigMixin`.
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Ignore private kwargs in the init.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])

        # Every init parameter except `self`, in declaration order.
        init_params = list(inspect.signature(init).parameters.items())[1:]

        # Align positional arguments with the FULL parameter list first and drop ignored names afterwards, so an
        # ignored parameter that is passed positionally cannot shift the following values onto the wrong names.
        positional = {name: arg for arg, (name, _) in zip(args, init_params)}
        new_kwargs = {name: arg for name, arg in positional.items() if name not in ignore}

        # Then add all kwargs, falling back to the signature default.
        for name, param in init_params:
            if name in ignore or name in new_kwargs:
                continue
            new_kwargs[name] = init_kwargs.get(name, param.default)

        # Take note of the parameters that were not passed explicitly. Positional arguments count as passed.
        explicitly_passed = set(positional) | set(init_kwargs)
        defaults_used = [name for name in new_kwargs if name not in explicitly_passed]
        if defaults_used:
            new_kwargs["_use_default_values"] = defaults_used

        new_kwargs = {**config_init_kwargs, **new_kwargs}
        getattr(self, "register_to_config")(**new_kwargs)
        init(self, *args, **init_kwargs)

    return inner_init
def flax_register_to_config(cls):
    """
    Class decorator that wraps `cls.__init__` so the dataclass field values are sent to `self.register_to_config`
    before the original init runs.

    Fields listed in `self._flax_internal_args` and the `dtype` keyword are left out of the registered config
    (a `dtype` passed positionally is still registered).

    Returns:
        The same `cls`, with `__init__` replaced by the wrapping init.
    """
    original_init = cls.__init__

    @functools.wraps(original_init)
    def init(self, *args, **kwargs):
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@flax_register_to_config` was applied to {self.__class__.__name__} init method, but this class "
                "does not inherit from `ConfigMixin`."
            )

        # Ignore private kwargs in the init. Retrieve all passed attributes
        init_kwargs = dict(kwargs.items())

        # Retrieve default values
        fields = dataclasses.fields(self)
        default_kwargs = {}
        for field in fields:
            # ignore flax specific attributes
            if field.name in self._flax_internal_args:
                continue
            # `dataclasses.MISSING` is the public sentinel for "no default"; compare by identity.
            if field.default is dataclasses.MISSING:
                default_kwargs[field.name] = None
            else:
                default_kwargs[field.name] = getattr(self, field.name)

        # Make sure init_kwargs override default kwargs
        new_kwargs = {**default_kwargs, **init_kwargs}
        # dtype should be part of `init_kwargs`, but not `new_kwargs`
        new_kwargs.pop("dtype", None)

        # Get positional arguments aligned with kwargs
        positional_names = set()
        for i, arg in enumerate(args):
            name = fields[i].name
            new_kwargs[name] = arg
            positional_names.add(name)

        # Take note of the parameters that were not passed explicitly. Positional arguments count as passed.
        explicitly_passed = positional_names | set(init_kwargs)
        defaults_used = [name for name in new_kwargs if name not in explicitly_passed]
        if defaults_used:
            new_kwargs["_use_default_values"] = defaults_used

        getattr(self, "register_to_config")(**new_kwargs)
        original_init(self, *args, **kwargs)

    cls.__init__ = init
    return cls


class LegacyConfigMixin(ConfigMixin):
    r"""
    A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
    pipeline-specific classes (like `DiTTransformer2DModel`).
    """

    @classmethod
    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
        """Resolve the remapped class for `config` and delegate to that class's `from_config`."""
        # To prevent dependency import problem.
        from .models.model_loading_utils import _fetch_remapped_cls_from_config

        # resolve remapping
        remapped_class = _fetch_remapped_cls_from_config(config, cls)

        return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
# define which module versions we always want to check at run time
# (usually the ones defined in `install_requires` in setup.py)
#
# order specific notes:
# - tqdm must be checked before tokenizers
pkgs_to_check_at_runtime = ["python", "requests", "filelock", "numpy"]

# Runs at import time: every listed package must have an entry in `deps` and pass its version check.
for pkg in pkgs_to_check_at_runtime:
    if pkg not in deps:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
    require_version_core(deps[pkg])


def dep_version_check(pkg, hint=None):
    """Look up the requirement string registered for `pkg` in `deps` and pass it, with `hint`, to `require_version`."""
    requirement = deps[pkg]
    require_version(requirement, hint)
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update`
#
# Maps a package name to its requirement string; the version-check module looks entries up by key.
deps = {
    "Pillow": "Pillow",
    "accelerate": "accelerate>=0.31.0",
    "compel": "compel==0.1.8",
    "datasets": "datasets",
    "filelock": "filelock",
    "flax": "flax>=0.4.1",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.23.2",
    "requests-mock": "requests-mock==1.10.0",
    "importlib_metadata": "importlib_metadata",
    "invisible-watermark": "invisible-watermark>=0.2.0",
    "isort": "isort>=5.5.4",
    "jax": "jax>=0.4.1",
    "jaxlib": "jaxlib>=0.4.1",
    "Jinja2": "Jinja2",
    "k-diffusion": "k-diffusion>=0.0.12",
    "torchsde": "torchsde",
    "note_seq": "note_seq",
    "librosa": "librosa",
    "numpy": "numpy",
    "parameterized": "parameterized",
    "peft": "peft>=0.6.0",
    "protobuf": "protobuf>=3.20.3,<4",
    "pytest": "pytest",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
    "python": "python>=3.8.0",
    "ruff": "ruff==0.1.5",
    "safetensors": "safetensors>=0.3.1",
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
    "GitPython": "GitPython<3.1.19",
    "scipy": "scipy",
    "onnx": "onnx",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "tensorboard": "tensorboard",
    "torch": "torch>=1.4",
    "torchvision": "torchvision",
    "transformers": "transformers>=4.41.2",
    "urllib3": "urllib3<=2.0.0",
    "black": "black",
}
class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        value_function ([`UNet1DModel`]):
            A specialized UNet for fine-tuning trajectories base on reward.
        unet ([`UNet1DModel`]):
            UNet architecture to denoise the encoded trajectories.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
            application is [`DDPMScheduler`].
        env ():
            An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()

        self.register_modules(value_function=value_function, unet=unet, scheduler=scheduler, env=env)

        self.data = env.get_dataset()

        # Best-effort statistics: entries whose `.mean()` / `.std()` call fails are skipped. Only `Exception` is
        # swallowed so that KeyboardInterrupt / SystemExit still propagate.
        self.means = {}
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except Exception:
                pass
        self.stds = {}
        for key in self.data.keys():
            try:
                self.stds[key] = self.data[key].std()
            except Exception:
                pass

        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        """Standardize `x_in` with the stored mean and std of dataset entry `key`."""
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        """Invert `normalize` for dataset entry `key`."""
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        """Move `x_in` (a tensor, a dict of convertible values, or array-like data) onto the UNet's device."""
        if isinstance(x_in, dict):
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        """Overwrite, in place, the entries after the first `act_dim` of each conditioned index with its value."""
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        """
        Run the denoising loop with `n_guide_steps` value-gradient updates per scheduler timestep.

        Returns:
            `tuple`: `(x, y)` where `x` is the final sample and `y` is the last value-function output, or `None`
            when no guidance step was executed.
        """
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # permute to match dimension for pre-trained models
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                # no guidance for batch entries whose timestep is below 2
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)

            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x)["prev_sample"]

            # apply conditions to the trajectory (set the initial state)
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        """
        Plan from observation `obs` and return the de-normalized first action of the selected trajectory.

        When guidance ran, trajectories are sorted by value and the best one is used. Otherwise (for example
        `n_guide_steps=0`) a random trajectory is used.
        """
        # normalize the observations and create batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = randn_tensor(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        # `y` stays `None` when no guidance step ran, so it must be checked BEFORE it is used for sorting.
        if y is not None:
            # sort output trajectories by value and select the action with the highest value
            sorted_idx = y.argsort(0, descending=True).squeeze()
            sorted_values = x[sorted_idx]
            selected_index = 0
        else:
            # if we didn't run value guiding, select a random action
            sorted_values = x
            selected_index = np.random.randint(0, batch_size)

        actions = sorted_values[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import warnings from typing import List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from PIL import Image, ImageFilter, ImageOps from .configuration_utils import ConfigMixin, register_to_config from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate PipelineImageInput = Union[ PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray], List[torch.Tensor], ] PipelineDepthInput = PipelineImageInput def is_valid_image(image): return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) def is_valid_image_imagelist(images): # check if the image input is one of the supported formats for image and image list: # it can be either one of below 3 # (1) a 4d pytorch tensor or numpy array, # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor # (3) a list of valid image if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: return True elif is_valid_image(images): return True elif isinstance(images, list): return all(is_valid_image(image) for image in images) return False class VaeImageProcessor(ConfigMixin): """ Image processor for VAE. Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. 
vae_scale_factor (`int`, *optional*, defaults to `8`): VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. resample (`str`, *optional*, defaults to `lanczos`): Resampling filter to use when resizing the image. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image to [-1,1]. do_binarize (`bool`, *optional*, defaults to `False`): Whether to binarize the image to 0/1. do_convert_rgb (`bool`, *optional*, defaults to be `False`): Whether to convert the images to RGB format. do_convert_grayscale (`bool`, *optional*, defaults to be `False`): Whether to convert the images to grayscale format. """ config_name = CONFIG_NAME @register_to_config def __init__( self, do_resize: bool = True, vae_scale_factor: int = 8, vae_latent_channels: int = 4, resample: str = "lanczos", do_normalize: bool = True, do_binarize: bool = False, do_convert_rgb: bool = False, do_convert_grayscale: bool = False, ): super().__init__() if do_convert_rgb and do_convert_grayscale: raise ValueError( "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`," " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.", " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`", ) @staticmethod def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: """ Convert a numpy image or a batch of images to a PIL image. """ if images.ndim == 3: images = images[None, ...] images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [Image.fromarray(image) for image in images] return pil_images @staticmethod def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: """ Convert a PIL image or a list of PIL images to NumPy arrays. 
""" if not isinstance(images, list): images = [images] images = [np.array(image).astype(np.float32) / 255.0 for image in images] images = np.stack(images, axis=0) return images @staticmethod def numpy_to_pt(images: np.ndarray) -> torch.Tensor: """ Convert a NumPy image to a PyTorch tensor. """ if images.ndim == 3: images = images[..., None] images = torch.from_numpy(images.transpose(0, 3, 1, 2)) return images @staticmethod def pt_to_numpy(images: torch.Tensor) -> np.ndarray: """ Convert a PyTorch tensor to a NumPy image. """ images = images.cpu().permute(0, 2, 3, 1).float().numpy() return images @staticmethod def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: """ Normalize an image array to [-1,1]. """ return 2.0 * images - 1.0 @staticmethod def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: """ Denormalize an image array to [0,1]. """ return (images / 2 + 0.5).clamp(0, 1) @staticmethod def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: """ Converts a PIL image to RGB format. """ image = image.convert("RGB") return image @staticmethod def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image: """ Converts a PIL image to grayscale format. """ image = image.convert("L") return image @staticmethod def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image: """ Applies Gaussian blur to an image. """ image = image.filter(ImageFilter.GaussianBlur(blur_factor)) return image @staticmethod def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0): """ Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128. Args: mask_image (PIL.Image.Image): Mask image. width (int): Width of the image to be processed. 
height (int): Height of the image to be processed. pad (int, optional): Padding to be added to the crop region. Defaults to 0. Returns: tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio. """ mask_image = mask_image.convert("L") mask = np.array(mask_image) # 1. find a rectangular region that contains all masked ares in an image h, w = mask.shape crop_left = 0 for i in range(w): if not (mask[:, i] == 0).all(): break crop_left += 1 crop_right = 0 for i in reversed(range(w)): if not (mask[:, i] == 0).all(): break crop_right += 1 crop_top = 0 for i in range(h): if not (mask[i] == 0).all(): break crop_top += 1 crop_bottom = 0 for i in reversed(range(h)): if not (mask[i] == 0).all(): break crop_bottom += 1 # 2. add padding to the crop region x1, y1, x2, y2 = ( int(max(crop_left - pad, 0)), int(max(crop_top - pad, 0)), int(min(w - crop_right + pad, w)), int(min(h - crop_bottom + pad, h)), ) # 3. expands crop region to match the aspect ratio of the image to be processed ratio_crop_region = (x2 - x1) / (y2 - y1) ratio_processing = width / height if ratio_crop_region > ratio_processing: desired_height = (x2 - x1) / ratio_processing desired_height_diff = int(desired_height - (y2 - y1)) y1 -= desired_height_diff // 2 y2 += desired_height_diff - desired_height_diff // 2 if y2 >= mask_image.height: diff = y2 - mask_image.height y2 -= diff y1 -= diff if y1 < 0: y2 -= y1 y1 -= y1 if y2 >= mask_image.height: y2 = mask_image.height else: desired_width = (y2 - y1) * ratio_processing desired_width_diff = int(desired_width - (x2 - x1)) x1 -= desired_width_diff // 2 x2 += desired_width_diff - desired_width_diff // 2 if x2 >= mask_image.width: diff = x2 - mask_image.width x2 -= diff x1 -= diff if x1 < 0: x2 -= x1 x1 -= x1 if x2 >= mask_image.width: x2 = mask_image.width return x1, y1, x2, y2 def _resize_and_fill( self, image: PIL.Image.Image, width: int, height: int, ) -> PIL.Image.Image: """ Resize the 
image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. Args: image: The image to resize. width: The width to resize the image to. height: The height to resize the image to. """ ratio = width / height src_ratio = image.width / image.height src_w = width if ratio < src_ratio else image.width * height // image.height src_h = height if ratio >= src_ratio else image.height * width // image.width resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"]) res = Image.new("RGB", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) if ratio < src_ratio: fill_height = height // 2 - src_h // 2 if fill_height > 0: res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) res.paste( resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h), ) elif ratio > src_ratio: fill_width = width // 2 - src_w // 2 if fill_width > 0: res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) res.paste( resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0), ) return res def _resize_and_crop( self, image: PIL.Image.Image, width: int, height: int, ) -> PIL.Image.Image: """ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. Args: image: The image to resize. width: The width to resize the image to. height: The height to resize the image to. 
""" ratio = width / height src_ratio = image.width / image.height src_w = width if ratio > src_ratio else image.width * height // image.height src_h = height if ratio <= src_ratio else image.height * width // image.width resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"]) res = Image.new("RGB", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) return res def resize( self, image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], height: int, width: int, resize_mode: str = "default", # "default", "fill", "crop" ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: """ Resize image. Args: image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`): The image input, can be a PIL image, numpy array or pytorch tensor. height (`int`): The height to resize to. width (`int`): The width to resize to. resize_mode (`str`, *optional*, defaults to `default`): The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only supported for PIL image input. Returns: `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: The resized image. 
""" if resize_mode != "default" and not isinstance(image, PIL.Image.Image): raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}") if isinstance(image, PIL.Image.Image): if resize_mode == "default": image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) elif resize_mode == "fill": image = self._resize_and_fill(image, width, height) elif resize_mode == "crop": image = self._resize_and_crop(image, width, height) else: raise ValueError(f"resize_mode {resize_mode} is not supported") elif isinstance(image, torch.Tensor): image = torch.nn.functional.interpolate( image, size=(height, width), ) elif isinstance(image, np.ndarray): image = self.numpy_to_pt(image) image = torch.nn.functional.interpolate( image, size=(height, width), ) image = self.pt_to_numpy(image) return image def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image: """ Create a mask. Args: image (`PIL.Image.Image`): The image input, should be a PIL image. Returns: `PIL.Image.Image`: The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1. """ image[image < 0.5] = 0 image[image >= 0.5] = 1 return image def get_default_height_width( self, image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], height: Optional[int] = None, width: Optional[int] = None, ) -> Tuple[int, int]: """ This function return the height and width that are downscaled to the next integer multiple of `vae_scale_factor`. Args: image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`): The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should have shape `[batch, channel, height, width]`. height (`int`, *optional*, defaults to `None`): The height in preprocessed image. If `None`, will use the height of `image` input. width (`int`, *optional*`, defaults to `None`): The width in preprocessed. 
If `None`, will use the width of the `image` input. """ if height is None: if isinstance(image, PIL.Image.Image): height = image.height elif isinstance(image, torch.Tensor): height = image.shape[2] else: height = image.shape[1] if width is None: if isinstance(image, PIL.Image.Image): width = image.width elif isinstance(image, torch.Tensor): width = image.shape[3] else: width = image.shape[2] width, height = ( x - x % self.config.vae_scale_factor for x in (width, height) ) # resize to integer multiple of vae_scale_factor return height, width def preprocess( self, image: PipelineImageInput, height: Optional[int] = None, width: Optional[int] = None, resize_mode: str = "default", # "default", "fill", "crop" crops_coords: Optional[Tuple[int, int, int, int]] = None, ) -> torch.Tensor: """ Preprocess the image input. Args: image (`pipeline_image_input`): The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats. height (`int`, *optional*, defaults to `None`): The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height. width (`int`, *optional*`, defaults to `None`): The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. resize_mode (`str`, *optional*, defaults to `default`): The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only supported for PIL image input. 
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): The crop coordinates for each image in the batch. If `None`, will not crop the image. """ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3: if isinstance(image, torch.Tensor): # if image is a pytorch tensor could have 2 possible shapes: # 1. batch x height x width: we should insert the channel dimension at position 1 # 2. channel x height x width: we should insert batch dimension at position 0, # however, since both channel and batch dimension has same size 1, it is same to insert at position 1 # for simplicity, we insert a dimension of size 1 at position 1 for both cases image = image.unsqueeze(1) else: # if it is a numpy array, it could have 2 possible shapes: # 1. batch x height x width: insert channel dimension on last position # 2. height x width x channel: insert batch dimension on first position if image.shape[-1] == 1: image = np.expand_dims(image, axis=0) else: image = np.expand_dims(image, axis=-1) if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4: warnings.warn( "Passing `image` as a list of 4d np.ndarray is deprecated." "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray", FutureWarning, ) image = np.concatenate(image, axis=0) if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: warnings.warn( "Passing `image` as a list of 4d torch.Tensor is deprecated." "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor", FutureWarning, ) image = torch.cat(image, axis=0) if not is_valid_image_imagelist(image): raise ValueError( f"Input is in incorrect format. 
Currently, we only support {', '.join(str(x) for x in supported_formats)}" ) if not isinstance(image, list): image = [image] if isinstance(image[0], PIL.Image.Image): if crops_coords is not None: image = [i.crop(crops_coords) for i in image] if self.config.do_resize: height, width = self.get_default_height_width(image[0], height, width) image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image] if self.config.do_convert_rgb: image = [self.convert_to_rgb(i) for i in image] elif self.config.do_convert_grayscale: image = [self.convert_to_grayscale(i) for i in image] image = self.pil_to_numpy(image) # to np image = self.numpy_to_pt(image) # to pt elif isinstance(image[0], np.ndarray): image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) image = self.numpy_to_pt(image) height, width = self.get_default_height_width(image, height, width) if self.config.do_resize: image = self.resize(image, height, width) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) if self.config.do_convert_grayscale and image.ndim == 3: image = image.unsqueeze(1) channel = image.shape[1] # don't need any preprocess if the image is latents if channel == self.vae_latent_channels: return image height, width = self.get_default_height_width(image, height, width) if self.config.do_resize: image = self.resize(image, height, width) # expected range [0,1], normalize to [-1,1] do_normalize = self.config.do_normalize if do_normalize and image.min() < 0: warnings.warn( "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " f"when passing as pytorch tensor or numpy Array. 
You passed `image` with value range [{image.min()},{image.max()}]", FutureWarning, ) do_normalize = False if do_normalize: image = self.normalize(image) if self.config.do_binarize: image = self.binarize(image) return image def postprocess( self, image: torch.Tensor, output_type: str = "pil", do_denormalize: Optional[List[bool]] = None, ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: """ Postprocess the image output from tensor to `output_type`. Args: image (`torch.Tensor`): The image input, should be a pytorch tensor with shape `B x C x H x W`. output_type (`str`, *optional*, defaults to `pil`): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. do_denormalize (`List[bool]`, *optional*, defaults to `None`): Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the `VaeImageProcessor` config. Returns: `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: The postprocessed image. """ if not isinstance(image, torch.Tensor): raise ValueError( f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" ) if output_type not in ["latent", "pt", "np", "pil"]: deprecation_message = ( f"the output_type {output_type} is outdated and has been set to `np`. 
Please make sure to set it to one of these instead: " "`pil`, `np`, `pt`, `latent`" ) deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) output_type = "np" if output_type == "latent": return image if do_denormalize is None: do_denormalize = [self.config.do_normalize] * image.shape[0] image = torch.stack( [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] ) if output_type == "pt": return image image = self.pt_to_numpy(image) if output_type == "np": return image if output_type == "pil": return self.numpy_to_pil(image) def apply_overlay( self, mask: PIL.Image.Image, init_image: PIL.Image.Image, image: PIL.Image.Image, crop_coords: Optional[Tuple[int, int, int, int]] = None, ) -> PIL.Image.Image: """ overlay the inpaint output to the original image """ width, height = image.width, image.height init_image = self.resize(init_image, width=width, height=height) mask = self.resize(mask, width=width, height=height) init_image_masked = PIL.Image.new("RGBa", (width, height)) init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L"))) init_image_masked = init_image_masked.convert("RGBA") if crop_coords is not None: x, y, x2, y2 = crop_coords w = x2 - x h = y2 - y base_image = PIL.Image.new("RGBA", (width, height)) image = self.resize(image, height=h, width=w, resize_mode="crop") base_image.paste(image, (x, y)) image = base_image.convert("RGB") image = image.convert("RGBA") image.alpha_composite(init_image_masked) image = image.convert("RGB") return image class VaeImageProcessorLDM3D(VaeImageProcessor): """ Image processor for VAE LDM3D. Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. vae_scale_factor (`int`, *optional*, defaults to `8`): VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. 
resample (`str`, *optional*, defaults to `lanczos`): Resampling filter to use when resizing the image. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image to [-1,1]. """ config_name = CONFIG_NAME @register_to_config def __init__( self, do_resize: bool = True, vae_scale_factor: int = 8, resample: str = "lanczos", do_normalize: bool = True, ): super().__init__() @staticmethod def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: """ Convert a NumPy image or a batch of images to a PIL image. """ if images.ndim == 3: images = images[None, ...] images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [Image.fromarray(image[:, :, :3]) for image in images] return pil_images @staticmethod def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: """ Convert a PIL image or a list of PIL images to NumPy arrays. """ if not isinstance(images, list): images = [images] images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images] images = np.stack(images, axis=0) return images @staticmethod def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: """ Args: image: RGB-like depth image Returns: depth map """ return image[:, :, 1] * 2**8 + image[:, :, 2] def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]: """ Convert a NumPy depth image or a batch of images to a PIL image. """ if images.ndim == 3: images = images[None, ...] 
images_depth = images[:, :, :, 3:] if images.shape[-1] == 6: images_depth = (images_depth * 255).round().astype("uint8") pil_images = [ Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth ] elif images.shape[-1] == 4: images_depth = (images_depth * 65535.0).astype(np.uint16) pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth] else: raise Exception("Not supported") return pil_images def postprocess( self, image: torch.Tensor, output_type: str = "pil", do_denormalize: Optional[List[bool]] = None, ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: """ Postprocess the image output from tensor to `output_type`. Args: image (`torch.Tensor`): The image input, should be a pytorch tensor with shape `B x C x H x W`. output_type (`str`, *optional*, defaults to `pil`): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. do_denormalize (`List[bool]`, *optional*, defaults to `None`): Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the `VaeImageProcessor` config. Returns: `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: The postprocessed image. """ if not isinstance(image, torch.Tensor): raise ValueError( f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" ) if output_type not in ["latent", "pt", "np", "pil"]: deprecation_message = ( f"the output_type {output_type} is outdated and has been set to `np`. 
def preprocess(
    self,
    rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
    depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
    height: Optional[int] = None,
    width: Optional[int] = None,
    target_res: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Preprocess an RGB image and its depth map. Accepted formats are PIL images and NumPy arrays (a single item or a
    list of them); PyTorch tensors are recognised but not supported yet.

    Args:
        rgb (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor` or a list of these):
            The RGB input.
        depth (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor` or a list of these):
            The depth input. When `rgb` is a single item, `depth` is wrapped into a list the same way.
        height (`int`, *optional*):
            Target height, forwarded to `get_default_height_width`.
        width (`int`, *optional*):
            Target width, forwarded to `get_default_height_width`.
        target_res (*optional*):
            Only used for PIL inputs. When set, it is unpacked as `(height, width)` and takes precedence over
            `height`/`width`, so it has to be a pair despite the `int` annotation.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: The processed `(rgb, depth)` pair.

    Raises:
        ValueError: If `rgb` is neither a supported type nor a list of supported types.
        Exception: For configurations that are not implemented yet (grayscale conversion of 3-dimensional
            arrays/tensors, RGB conversion of PIL inputs, tensor inputs).
    """
    supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

    # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
    if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3:
        raise Exception("This is not yet supported")

    if isinstance(rgb, supported_formats):
        rgb = [rgb]
        depth = [depth]
    elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
        # Types are not strings, so they must be converted before joining; otherwise the join itself raises a
        # TypeError and hides the real problem from the caller.
        received = [type(i) for i in rgb] if isinstance(rgb, list) else type(rgb)
        supported_names = ", ".join(str(fmt) for fmt in supported_formats)
        raise ValueError(
            f"Input is in incorrect format: {received}. Currently, we only support {supported_names}"
        )

    if isinstance(rgb[0], PIL.Image.Image):
        if self.config.do_convert_rgb:
            # TODO: needs a `convert_to_depth` counterpart of `convert_to_rgb` before this can be enabled.
            raise Exception("This is not yet supported")
        if self.config.do_resize or target_res:
            height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
            rgb = [self.resize(i, height, width) for i in rgb]
            depth = [self.resize(i, height, width) for i in depth]
        rgb = self.pil_to_numpy(rgb)  # to np
        rgb = self.numpy_to_pt(rgb)  # to pt

        depth = self.depth_pil_to_numpy(depth)  # to np
        depth = self.numpy_to_pt(depth)  # to pt

    elif isinstance(rgb[0], np.ndarray):
        rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
        rgb = self.numpy_to_pt(rgb)
        height, width = self.get_default_height_width(rgb, height, width)
        if self.config.do_resize:
            rgb = self.resize(rgb, height, width)

        # Decide how to batch the depth maps from the depth arrays themselves: `rgb` has already been turned into
        # a tensor above, so its per-item rank no longer says anything about the raw inputs.
        depth = np.concatenate(depth, axis=0) if depth[0].ndim == 4 else np.stack(depth, axis=0)
        depth = self.numpy_to_pt(depth)
        height, width = self.get_default_height_width(depth, height, width)
        if self.config.do_resize:
            depth = self.resize(depth, height, width)

    elif isinstance(rgb[0], torch.Tensor):
        # TODO: tensor inputs (including latents) are not handled yet.
        raise Exception("This is not yet supported")

    # expected range [0,1], normalize to [-1,1]
    do_normalize = self.config.do_normalize
    if rgb.min() < 0 and do_normalize:
        warnings.warn(
            "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
            f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
            FutureWarning,
        )
        do_normalize = False

    if do_normalize:
        rgb = self.normalize(rgb)
        depth = self.normalize(depth)

    if self.config.do_binarize:
        rgb = self.binarize(rgb)
        depth = self.binarize(depth)

    return rgb, depth
""" config_name = CONFIG_NAME @register_to_config def __init__( self, do_resize: bool = True, vae_scale_factor: int = 8, resample: str = "lanczos", do_normalize: bool = False, do_binarize: bool = True, do_convert_grayscale: bool = True, ): super().__init__( do_resize=do_resize, vae_scale_factor=vae_scale_factor, resample=resample, do_normalize=do_normalize, do_binarize=do_binarize, do_convert_grayscale=do_convert_grayscale, ) @staticmethod def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int): """ Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued. Args: mask (`torch.Tensor`): The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`. batch_size (`int`): The batch size. num_queries (`int`): The number of queries. value_embed_dim (`int`): The dimensionality of the value embeddings. Returns: `torch.Tensor`: The downsampled mask tensor. """ o_h = mask.shape[1] o_w = mask.shape[2] ratio = o_w / o_h mask_h = int(math.sqrt(num_queries / ratio)) mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0) mask_w = num_queries // mask_h mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0) # Repeat batch_size times if mask_downsample.shape[0] < batch_size: mask_downsample = mask_downsample.repeat(batch_size, 1, 1) mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1) downsampled_area = mask_h * mask_w # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries if downsampled_area < num_queries: warnings.warn( "The aspect ratio of the mask does not match the aspect ratio of the output image. 
" "Please update your masks or adjust the output size for optimal performance.", UserWarning, ) mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0) # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries if downsampled_area > num_queries: warnings.warn( "The aspect ratio of the mask does not match the aspect ratio of the output image. " "Please update your masks or adjust the output size for optimal performance.", UserWarning, ) mask_downsample = mask_downsample[:, :num_queries] # Repeat last dimension to match SDPA output shape mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat( 1, 1, value_embed_dim ) return mask_downsample class PixArtImageProcessor(VaeImageProcessor): """ Image processor for PixArt image resize and crop. Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. vae_scale_factor (`int`, *optional*, defaults to `8`): VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. resample (`str`, *optional*, defaults to `lanczos`): Resampling filter to use when resizing the image. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image to [-1,1]. do_binarize (`bool`, *optional*, defaults to `False`): Whether to binarize the image to 0/1. do_convert_rgb (`bool`, *optional*, defaults to be `False`): Whether to convert the images to RGB format. do_convert_grayscale (`bool`, *optional*, defaults to be `False`): Whether to convert the images to grayscale format. 
""" @register_to_config def __init__( self, do_resize: bool = True, vae_scale_factor: int = 8, resample: str = "lanczos", do_normalize: bool = True, do_binarize: bool = False, do_convert_grayscale: bool = False, ): super().__init__( do_resize=do_resize, vae_scale_factor=vae_scale_factor, resample=resample, do_normalize=do_normalize, do_binarize=do_binarize, do_convert_grayscale=do_convert_grayscale, ) @staticmethod def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]: """Returns binned height and width.""" ar = float(height / width) closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar)) default_hw = ratios[closest_ratio] return int(default_hw[0]), int(default_hw[1]) @staticmethod def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor: orig_height, orig_width = samples.shape[2], samples.shape[3] # Check if resizing is needed if orig_height != new_height or orig_width != new_width: ratio = max(new_height / orig_height, new_width / orig_width) resized_width = int(orig_width * ratio) resized_height = int(orig_height * ratio) # Resize samples = F.interpolate( samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False ) # Center Crop start_x = (resized_width - new_width) // 2 end_x = start_x + new_width start_y = (resized_height - new_height) // 2 end_y = start_y + new_height samples = samples[:, :, start_y:end_y, start_x:end_x] return samples ================================================ FILE: src/diffusers/loaders/__init__.py ================================================ from typing import TYPE_CHECKING from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate from ..utils.import_utils import is_peft_available, is_torch_available, is_transformers_available def text_encoder_lora_state_dict(text_encoder): deprecate( "text_encoder_load_state_dict in `models`", "0.27.0", "`text_encoder_lora_state_dict` is deprecated and will be removed 
if is_transformers_available():

    def text_encoder_attn_modules(text_encoder):
        """
        Collect the self-attention modules of a CLIP text encoder (deprecated).

        Args:
            text_encoder (`CLIPTextModel` or `CLIPTextModelWithProjection`):
                The text encoder whose encoder layers are inspected.

        Returns:
            `List[Tuple[str, torch.nn.Module]]`: One `(name, module)` pair per encoder layer, where `name` is
            `text_model.encoder.layers.{i}.self_attn` and `module` is that layer's `self_attn`.

        Raises:
            ValueError: If `text_encoder` is not one of the supported CLIP text model classes.
        """
        deprecate(
            "text_encoder_attn_modules in `models`",
            "0.27.0",
            "`text_encoder_attn_modules` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
        )
        # Imported lazily so the module can be imported without `transformers` installed.
        from transformers import CLIPTextModel, CLIPTextModelWithProjection

        if not isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
            raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")

        return [
            (f"text_model.encoder.layers.{i}.self_attn", layer.self_attn)
            for i, layer in enumerate(text_encoder.text_model.encoder.layers)
        ]
class IPAdapterMixin:
    """Mixin for handling IP Adapters."""

    @validate_hf_hub_args
    def load_ip_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
        subfolder: Union[str, List[str]],
        weight_name: Union[str, List[str]],
        image_encoder_folder: Optional[str] = "image_encoder",
        **kwargs,
    ):
        """
        Load one or several IP-Adapter checkpoints into the pipeline's UNet.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
            subfolder (`str` or `List[str]`):
                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
                list is passed, it should have the same length as `weight_name`.
            weight_name (`str` or `List[str]`):
                The name of the weight file to load. If a list is passed, it should have the same length as
                `subfolder` and `pretrained_model_name_or_path_or_dict` (single values are repeated to match).
            image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
                The subfolder location of the image encoder within a larger model repository on the Hub or locally.
                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
                `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
                `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
                `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
                `image_encoder_folder="different_subfolder/image_encoder"`.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.

        Raises:
            ValueError: If the list arguments have mismatching lengths, if a state dict does not consist of exactly
                the `image_proj` and `ip_adapter` entries, or if the image encoder has to be loaded but the weights
                were passed as a state dict.
            NotImplementedError: If `low_cpu_mem_usage=True` is requested with torch < 1.9.0.
        """

        # handle the list inputs for multiple IP Adapters
        if not isinstance(weight_name, list):
            weight_name = [weight_name]

        if not isinstance(pretrained_model_name_or_path_or_dict, list):
            pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
        if len(pretrained_model_name_or_path_or_dict) == 1:
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)

        if not isinstance(subfolder, list):
            subfolder = [subfolder]
        if len(subfolder) == 1:
            subfolder = subfolder * len(weight_name)

        if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
            raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")

        if len(weight_name) != len(subfolder):
            raise ValueError("`weight_name` and `subfolder` must have the same length.")

        # Load the main state dict first.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }
        required_keys = {"image_proj", "ip_adapter"}

        state_dicts = []
        # Dedicated loop names keep the per-adapter values apart from the list arguments they are drawn from.
        for model_source, adapter_weight_name, adapter_subfolder in zip(
            pretrained_model_name_or_path_or_dict, weight_name, subfolder
        ):
            source_is_state_dict = isinstance(model_source, dict)

            if source_is_state_dict:
                state_dict = model_source
            else:
                model_file = _get_model_file(
                    model_source,
                    weights_name=adapter_weight_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=adapter_subfolder,
                    user_agent=user_agent,
                )
                if adapter_weight_name.endswith(".safetensors"):
                    # Safetensors files are flat: regroup the tensors under their top-level prefix.
                    state_dict = {"image_proj": {}, "ip_adapter": {}}
                    with safe_open(model_file, framework="pt", device="cpu") as f:
                        for key in f.keys():
                            for group in ("image_proj", "ip_adapter"):
                                prefix = f"{group}."
                                if key.startswith(prefix):
                                    state_dict[group][key[len(prefix) :]] = f.get_tensor(key)
                                    break
                else:
                    state_dict = load_state_dict(model_file)

            # Compare as a set: the two entries are looked up by name, so their order in the dict is irrelevant.
            if set(state_dict.keys()) != required_keys:
                raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")

            state_dicts.append(state_dict)

            # load CLIP image encoder here if it has not been registered to the pipeline yet
            if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
                if image_encoder_folder is not None:
                    if source_is_state_dict:
                        raise ValueError(
                            "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
                        )
                    logger.info(f"loading image_encoder from {model_source}")
                    if image_encoder_folder.count("/") == 0:
                        image_encoder_subfolder = Path(adapter_subfolder, image_encoder_folder).as_posix()
                    else:
                        image_encoder_subfolder = Path(image_encoder_folder).as_posix()

                    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                        model_source,
                        subfolder=image_encoder_subfolder,
                        low_cpu_mem_usage=low_cpu_mem_usage,
                    ).to(self.device, dtype=self.dtype)
                    self.register_modules(image_encoder=image_encoder)
                else:
                    logger.warning(
                        "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter. "
                        "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
                    )

            # create feature extractor if it has not been registered to the pipeline yet
            if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
                # The image encoder may legitimately be absent (`image_encoder_folder=None`); fall back to the
                # standard CLIP resolution instead of dereferencing `None`.
                default_clip_size = 224
                image_encoder = getattr(self, "image_encoder", None)
                clip_image_size = image_encoder.config.image_size if image_encoder is not None else default_clip_size
                feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
                self.register_modules(feature_extractor=feature_extractor)

        # load ip-adapter into unet
        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)

        extra_loras = unet._load_ip_adapter_loras(state_dicts)
        if extra_loras != {}:
            if not USE_PEFT_BACKEND:
                logger.warning("PEFT backend is required to load these weights.")
            else:
                # apply the IP Adapter Face ID LoRA weights
                peft_config = getattr(unet, "peft_config", {})
                for k, lora in extra_loras.items():
                    if f"faceid_{k}" not in peft_config:
                        self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
                        self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])

    def set_ip_adapter_scale(self, scale):
        """
        Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
        granular control over each IP-Adapter behavior. A config can be a float or a dictionary.

        Example:

        ```py
        # To use original IP-Adapter
        scale = 1.0
        pipeline.set_ip_adapter_scale(scale)

        # To use style block only
        scale = {
            "up": {"block_0": [0.0, 1.0, 0.0]},
        }
        pipeline.set_ip_adapter_scale(scale)

        # To use style+layout blocks
        scale = {
            "down": {"block_2": [0.0, 1.0]},
            "up": {"block_0": [0.0, 1.0, 0.0]},
        }
        pipeline.set_ip_adapter_scale(scale)

        # To use style and layout from 2 reference images
        scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
        pipeline.set_ip_adapter_scale(scales)
        ```

        Raises:
            ValueError: If the number of scale configs differs from the number of IP-Adapters held by an attention
                processor.
        """
        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
        if not isinstance(scale, list):
            scale = [scale]
        scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)

        ip_adapter_processor_types = (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)
        for attn_name, attn_processor in unet.attn_processors.items():
            if not isinstance(attn_processor, ip_adapter_processor_types):
                continue

            if len(scale_configs) != len(attn_processor.scale):
                raise ValueError(
                    f"Cannot assign {len(scale_configs)} scale_configs to "
                    f"{len(attn_processor.scale)} IP-Adapter."
                )

            for i, scale_config in enumerate(scale_configs):
                if isinstance(scale_config, dict):
                    # Dict configs are keyed by attention-processor name prefix.
                    for k, s in scale_config.items():
                        if attn_name.startswith(k):
                            attn_processor.scale[i] = s
                else:
                    attn_processor.scale[i] = scale_config

    def unload_ip_adapter(self):
        """
        Unloads the IP Adapter weights

        Examples:

        ```python
        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
        >>> pipeline.unload_ip_adapter()
        >>> ...
        ```
        """
        # remove CLIP image encoder
        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
            self.image_encoder = None
            self.register_to_config(image_encoder=[None, None])

        # the feature extractor is kept whenever the pipeline has a `safety_checker` attribute, because the
        # safety checker uses the feature_extractor later
        if not hasattr(self, "safety_checker"):
            if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
                self.feature_extractor = None
                self.register_to_config(feature_extractor=[None, None])

        # Resolve the UNet the same way the loading and scaling methods do.
        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet

        # remove hidden encoder
        unet.encoder_hid_proj = None
        unet.config.encoder_hid_dim_type = None

        # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
        if hasattr(unet, "text_encoder_hid_proj") and unet.text_encoder_hid_proj is not None:
            unet.encoder_hid_proj = unet.text_encoder_hid_proj
            unet.text_encoder_hid_proj = None
            unet.config.encoder_hid_dim_type = "text_proj"

        # restore original Unet attention processors layers
        ip_adapter_processor_types = (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)
        default_processor_cls = AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
        attn_procs = {}
        for name, value in unet.attn_processors.items():
            # IP-Adapter processors fall back to the default processor; all others are re-created from their class.
            attn_procs[name] = (
                default_processor_cls() if isinstance(value, ip_adapter_processor_types) else value.__class__()
            )
        unet.set_attn_processor(attn_procs)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
    """
    Fuses LoRAs for the text encoder.

    Args:
        text_encoder (`torch.nn.Module`):
            The text encoder module whose LoRA (tuner) layers are merged. Must not be `None`.
        lora_scale (`float`, defaults to 1.0):
            Controls how much to influence the outputs with the LoRA parameters.
        safe_fusing (`bool`, defaults to `False`):
            Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
        adapter_names (`List[str]` or `str`):
            The names of the adapters to use.

    Raises:
        ValueError: If `adapter_names` is given but the installed PEFT version's `merge` does not accept it.
    """
    for module in text_encoder.modules():
        if not isinstance(module, BaseTunerLayer):
            continue

        # Build the kwargs per layer so a key accepted by one layer never leaks into a layer that lacks it.
        merge_kwargs = {"safe_merge": safe_fusing}

        # For BC with previous PEFT versions, we need to check the signature
        # of the `merge` method to see if it supports the `adapter_names` argument.
        if "adapter_names" in inspect.signature(module.merge).parameters:
            merge_kwargs["adapter_names"] = adapter_names
        elif adapter_names is not None:
            raise ValueError(
                "The `adapter_names` argument is not supported with your PEFT version. "
                "Please upgrade to the latest version of PEFT. `pip install -U peft`"
            )

        # Scale only once the call is known to be valid, so a rejected request leaves the layer untouched.
        if lora_scale != 1.0:
            module.scale_layer(lora_scale)

        module.merge(**merge_kwargs)
def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
    """
    Turn off the LoRA layers of a text encoder.

    Args:
        text_encoder (`torch.nn.Module`, *optional*):
            The text encoder whose LoRA layers get disabled. It has to be provided; `None` is rejected.

    Raises:
        ValueError: If `text_encoder` is `None`.
    """
    if text_encoder is not None:
        set_adapter_layers(text_encoder, enabled=False)
        return

    raise ValueError("Text Encoder not found.")
""" if text_encoder is None: raise ValueError("Text Encoder not found.") set_adapter_layers(text_encoder, enabled=True) def _remove_text_encoder_monkey_patch(text_encoder): recurse_remove_peft_layers(text_encoder) if getattr(text_encoder, "peft_config", None) is not None: del text_encoder.peft_config text_encoder._hf_peft_config_loaded = None class LoraBaseMixin: """Utility class for handling LoRAs.""" _lora_loadable_modules = [] num_fused_loras = 0 def load_lora_weights(self, **kwargs): raise NotImplementedError("`load_lora_weights()` is not implemented.") @classmethod def save_lora_weights(cls, **kwargs): raise NotImplementedError("`save_lora_weights()` not implemented.") @classmethod def lora_state_dict(cls, **kwargs): raise NotImplementedError("`lora_state_dict()` is not implemented.") @classmethod def _optionally_disable_offloading(cls, _pipeline): """ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. Args: _pipeline (`DiffusionPipeline`): The pipeline to disable offloading for. Returns: tuple: A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. """ is_model_cpu_offload = False is_sequential_cpu_offload = False if _pipeline is not None and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): if not is_model_cpu_offload: is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) if not is_sequential_cpu_offload: is_sequential_cpu_offload = ( isinstance(component._hf_hook, AlignDevicesHook) or hasattr(component._hf_hook, "hooks") and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) ) logger.info( "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." 
) remove_hook_from_module(component, recurse=is_sequential_cpu_offload) return (is_model_cpu_offload, is_sequential_cpu_offload) @classmethod def _fetch_state_dict( cls, pretrained_model_name_or_path_or_dict, weight_name, use_safetensors, local_files_only, cache_dir, force_download, proxies, token, revision, subfolder, user_agent, allow_pickle, ): from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): # Let's first try to load .safetensors weights if (use_safetensors and weight_name is None) or ( weight_name is not None and weight_name.endswith(".safetensors") ): try: # Here we're relaxing the loading check to enable more Inference API # friendliness where sometimes, it's not at all possible to automatically # determine `weight_name`. if weight_name is None: weight_name = cls._best_guess_weight_name( pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=local_files_only, ) model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, cache_dir=cache_dir, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") except (IOError, safetensors.SafetensorError) as e: if not allow_pickle: raise e # try loading non-safetensors weights model_file = None pass if model_file is None: if weight_name is None: weight_name = cls._best_guess_weight_name( pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only ) model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name or LORA_WEIGHT_NAME, cache_dir=cache_dir, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision, subfolder=subfolder, 
@classmethod
def _best_guess_weight_name(
    cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
):
    """
    Try to infer which file holds the LoRA weights when the caller gave no `weight_name`.

    Args:
        pretrained_model_name_or_path_or_dict:
            A local file path, a local directory, or (anything else) an identifier passed to `model_info`.
        file_extension (`str`, defaults to `".safetensors"`):
            Only files ending with this extension are considered.
        local_files_only (`bool`, defaults to `False`):
            When truthy (or when `HF_HUB_OFFLINE` is set) guessing is refused.

    Returns:
        The single candidate file name, or `None` when the input is a file or when no candidate is left.

    Raises:
        ValueError: In offline mode, or when more than one candidate file remains.
    """
    from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

    if local_files_only or HF_HUB_OFFLINE:
        raise ValueError("When using the offline mode, you must specify a `weight_name`.")

    if os.path.isfile(pretrained_model_name_or_path_or_dict):
        return None

    if os.path.isdir(pretrained_model_name_or_path_or_dict):
        targeted_files = [
            f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
        ]
    else:
        files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
        targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]

    # "scheduler" does not correspond to a LoRA checkpoint.
    # "optimizer" does not correspond to a LoRA checkpoint
    # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
    unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
    targeted_files = [
        f for f in targeted_files if all(substring not in f for substring in unallowed_substrings)
    ]

    # The filter above can remove every candidate. Checking for emptiness only before it
    # would let `targeted_files[0]` below fail with an IndexError, so check afterwards.
    if not targeted_files:
        return None

    # Prefer the canonical LoRA file names when any candidate ends with one of them.
    if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
        targeted_files = [f for f in targeted_files if f.endswith(LORA_WEIGHT_NAME)]
    elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
        targeted_files = [f for f in targeted_files if f.endswith(LORA_WEIGHT_NAME_SAFE)]

    if len(targeted_files) > 1:
        raise ValueError(
            f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
        )
    return targeted_files[0]
deprecate( "fuse_unet", "1.0.0", depr_message, ) if "fuse_transformer" in kwargs: depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version." deprecate( "fuse_transformer", "1.0.0", depr_message, ) if "fuse_text_encoder" in kwargs: depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version." deprecate( "fuse_text_encoder", "1.0.0", depr_message, ) if len(components) == 0: raise ValueError("`components` cannot be an empty list.") for fuse_component in components: if fuse_component not in self._lora_loadable_modules: raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.") model = getattr(self, fuse_component, None) if model is not None: # check if diffusers model if issubclass(model.__class__, ModelMixin): model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names) # handle transformers models. if issubclass(model.__class__, PreTrainedModel): fuse_text_encoder_lora( model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) self.num_fused_loras += 1 def unfuse_lora(self, components: List[str] = [], **kwargs): r""" Reverses the effect of [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). This is an experimental API. Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. unfuse_text_encoder (`bool`, defaults to `True`): Whether to unfuse the text encoder LoRA parameters. 
def set_adapters(
    self,
    adapter_names: Union[List[str], str],
    adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):
    """
    Activate `adapter_names` on every LoRA-loadable component, optionally with per-component weights.

    Args:
        adapter_names (`Union[List[str], str]`):
            One adapter name or a list of adapter names.
        adapter_weights (`float`, `Dict`, `List[float]` or `List[Dict]`, *optional*):
            A single value is applied to every adapter; a list must have one entry per adapter. A dict
            entry is looked up by component name (the names in `_lora_loadable_modules`); components
            missing from the dict receive `None`.

    Raises:
        ValueError: If the number of weights does not match the number of adapter names.
    """
    adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

    # Private copy: the weight objects handed to the models are never the caller's own.
    adapter_weights = copy.deepcopy(adapter_weights)

    # Expand weights into a list, one entry per adapter
    if not isinstance(adapter_weights, list):
        adapter_weights = [adapter_weights] * len(adapter_names)

    if len(adapter_names) != len(adapter_weights):
        raise ValueError(
            f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
        )

    list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
    all_adapters = {
        adapter for adapters in list_adapters.values() for adapter in adapters
    }  # eg ["adapter1", "adapter2"]
    invert_list_adapters = {
        adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
        for adapter in all_adapters
    }  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}

    # Decompose weights into weights for denoiser and text encoders.
    _component_adapter_weights = {}
    for component in self._lora_loadable_modules:
        # Default to None so a pipeline lacking this component reaches the warning below
        # instead of failing with AttributeError.
        model = getattr(self, component, None)

        for adapter_name, weights in zip(adapter_names, adapter_weights):
            if isinstance(weights, dict):
                # Read without popping: when one dict is shared by several adapters (a single
                # dict expanded above), popping would leave every adapter after the first
                # without its weight for this component.
                component_adapter_weights = weights.get(component, None)

                if component_adapter_weights is not None and not hasattr(self, component):
                    logger.warning(
                        f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}."
                    )

                # An adapter unknown to every component simply has no valid parts.
                valid_parts = invert_list_adapters.get(adapter_name, [])
                if component_adapter_weights is not None and component not in valid_parts:
                    logger.warning(
                        (
                            f"Lora weight dict for adapter '{adapter_name}' contains {component},"
                            f"but this will be ignored because {adapter_name} does not contain weights for {component}."
                            f"Valid parts for {adapter_name} are: {valid_parts}."
                        )
                    )
            else:
                component_adapter_weights = weights

            _component_adapter_weights.setdefault(component, [])
            _component_adapter_weights[component].append(component_adapter_weights)

        # A missing component (model is None) matches neither branch and is skipped.
        if issubclass(model.__class__, ModelMixin):
            model.set_adapters(adapter_names, _component_adapter_weights[component])
        elif issubclass(model.__class__, PreTrainedModel):
            set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
def get_active_adapters(self) -> List[str]:
    """
    Return the adapters that are currently active.

    For each LoRA-loadable component that is a `ModelMixin`, the first `BaseTunerLayer` found among its
    modules supplies the value; when several components qualify, the last one inspected wins. An empty
    list is returned when no such layer exists.

    Example:

    ```python
    from diffusers import DiffusionPipeline

    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
    ).to("cuda")
    pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
    pipeline.get_active_adapters()
    ```
    """
    if not USE_PEFT_BACKEND:
        raise ValueError(
            "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
        )

    result = []
    for component_name in self._lora_loadable_modules:
        candidate = getattr(self, component_name, None)
        if candidate is None or not issubclass(candidate.__class__, ModelMixin):
            continue
        # Stop at the first tuner layer of this component.
        first_tuner = next((layer for layer in candidate.modules() if isinstance(layer, BaseTunerLayer)), None)
        if first_tuner is not None:
            result = first_tuner.active_adapters

    return result
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
    """
    Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
    you want to load multiple adapters and free some GPU memory.

    Tuner layers that do not hold a given adapter are skipped, so an adapter that was loaded into only some
    components or layers can still be moved.

    Args:
        adapter_names (`List[str]`):
            List of adapters to send device to.
        device (`Union[torch.device, str, int]`):
            Device to send the adapters to. Can be either a torch device, a str or an integer.

    Raises:
        ValueError: If the PEFT backend is not in use.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    for component in self._lora_loadable_modules:
        model = getattr(self, component, None)
        if model is None:
            continue

        for module in model.modules():
            if not isinstance(module, BaseTunerLayer):
                continue

            for adapter_name in adapter_names:
                # Indexing `lora_A` with an adapter this layer does not carry would raise KeyError.
                # NOTE(review): assumes `lora_A` supports `in` by adapter name (mapping-like) - confirm.
                if adapter_name not in module.lora_A:
                    continue

                module.lora_A[adapter_name].to(device)
                module.lora_B[adapter_name].to(device)

                # this is a param, not a module, so device placement is not in-place -> re-assign
                magnitude = getattr(module, "lora_magnitude_vector", None)
                if magnitude is not None and adapter_name in magnitude:
                    magnitude[adapter_name] = magnitude[adapter_name].to(device)
def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
    """
    Rewrite the block indices of SGM-style keys (`input_blocks`, `middle_block`, `output_blocks`).

    When no key contains one of the SGM patterns, `state_dict` is returned as is (same object).
    Otherwise every entry is moved out of `state_dict` (which is emptied) into a new dict: keys containing
    `"text"` are copied verbatim, the others get their flat block id replaced by
    `<block id>, <resnets|attentions|upsamplers|downsamplers>, <layer index>` segments.

    Args:
        state_dict (`dict`): The LoRA state dict; it is consumed when remapping happens.
        unet_config: Object exposing `layers_per_block`, used to split flat ids into block/layer indices.
        delimiter (`str`, defaults to `"_"`): Separator between key segments.
        block_slice_pos (`int`, defaults to 5): Number of leading segments up to and including the block id.

    Returns:
        `dict`: The original dict (no SGM keys) or the remapped dict.

    Raises:
        ValueError: For a non-text key without a known SGM pattern, an invalid middle block id, or when
            entries are left unconverted.
    """
    # 1. get all state_dict_keys
    all_keys = list(state_dict.keys())
    sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]

    # 2. check if needs remapping, if not return original dict
    if not any(pattern in key for key in all_keys for pattern in sgm_patterns):
        return state_dict

    # 3. Else remap from SGM patterns
    new_state_dict = {}
    inner_block_map = ["resnets", "attentions", "upsamplers"]

    # Retrieves # of down, mid and up blocks
    input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()

    for layer in all_keys:
        if "text" in layer:
            new_state_dict[layer] = state_dict.pop(layer)
            continue

        layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
        if sgm_patterns[0] in layer:
            input_block_ids.add(layer_id)
        elif sgm_patterns[1] in layer:
            middle_block_ids.add(layer_id)
        elif sgm_patterns[2] in layer:
            output_block_ids.add(layer_id)
        else:
            raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")

    def _keys_for(pattern, layer_id):
        # The trailing delimiter keeps id 1 from also capturing ids 10, 11, ...; without it
        # those keys would be remapped under the wrong block and popped a second time.
        marker = f"{pattern}{delimiter}{layer_id}{delimiter}"
        return [key for key in state_dict if marker in key]

    input_blocks = {layer_id: _keys_for("input_blocks", layer_id) for layer_id in input_block_ids}
    middle_blocks = {layer_id: _keys_for("middle_block", layer_id) for layer_id in middle_block_ids}
    output_blocks = {layer_id: _keys_for("output_blocks", layer_id) for layer_id in output_block_ids}

    # Rename keys accordingly
    for i in input_block_ids:
        block_id = (i - 1) // (unet_config.layers_per_block + 1)
        layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)

        for key in input_blocks[i]:
            parts = key.split(delimiter)
            inner_block_id = int(parts[block_slice_pos])
            # NOTE(review): "op" is a substring test on the whole key, not on a segment.
            is_downsampler = "op" in key
            inner_block_key = "downsamplers" if is_downsampler else inner_block_map[inner_block_id]
            inner_layers_in_block = "0" if is_downsampler else str(layer_in_block_id)
            new_key = delimiter.join(
                parts[: block_slice_pos - 1]
                + [str(block_id), inner_block_key, inner_layers_in_block]
                + parts[block_slice_pos + 1 :]
            )
            new_state_dict[new_key] = state_dict.pop(key)

    for i in middle_block_ids:
        if i == 0:
            key_part = [inner_block_map[0], "0"]
        elif i == 1:
            key_part = [inner_block_map[1], "0"]
        elif i == 2:
            key_part = [inner_block_map[0], "1"]
        else:
            raise ValueError(f"Invalid middle block id {i}.")

        for key in middle_blocks[i]:
            parts = key.split(delimiter)
            new_key = delimiter.join(parts[: block_slice_pos - 1] + key_part + parts[block_slice_pos:])
            new_state_dict[new_key] = state_dict.pop(key)

    for i in output_block_ids:
        block_id = i // (unet_config.layers_per_block + 1)
        layer_in_block_id = i % (unet_config.layers_per_block + 1)

        for key in output_blocks[i]:
            parts = key.split(delimiter)
            inner_block_id = int(parts[block_slice_pos])
            inner_block_key = inner_block_map[inner_block_id]
            inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
            new_key = delimiter.join(
                parts[: block_slice_pos - 1]
                + [str(block_id), inner_block_key, inner_layers_in_block]
                + parts[block_slice_pos + 1 :]
            )
            new_state_dict[new_key] = state_dict.pop(key)

    if len(state_dict) > 0:
        raise ValueError("At this point all state dict entries have to be converted.")

    return new_state_dict
dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict) dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict) dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict) if dora_present_in_unet or dora_present_in_te or dora_present_in_te2: if is_peft_version("<", "0.9.0"): raise ValueError( "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." ) # Iterate over all LoRA weights. all_lora_keys = list(state_dict.keys()) for key in all_lora_keys: if not key.endswith("lora_down.weight"): continue # Extract LoRA name. lora_name = key.split(".")[0] # Find corresponding up weight and alpha. lora_name_up = lora_name + ".lora_up.weight" lora_name_alpha = lora_name + ".alpha" # Handle U-Net LoRAs. if lora_name.startswith("lora_unet_"): diffusers_name = _convert_unet_lora_key(key) # Store down and up weights. unet_state_dict[diffusers_name] = state_dict.pop(key) unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) # Store DoRA scale if present. if dora_present_in_unet: dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." unet_state_dict[ diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale")) # Handle text encoder LoRAs. elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")): diffusers_name = _convert_text_encoder_lora_key(key, lora_name) # Store down and up weights for te or te2. if lora_name.startswith(("lora_te_", "lora_te1_")): te_state_dict[diffusers_name] = state_dict.pop(key) te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) else: te2_state_dict[diffusers_name] = state_dict.pop(key) te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) # Store DoRA scale if present. 
def _convert_unet_lora_key(key):
    """
    Translate one U-Net LoRA key into the Diffusers naming scheme.

    The `lora_unet_` prefix is dropped, every underscore becomes a dot, and then a fixed, ordered series of
    substring renames restores the identifiers that legitimately contain underscores.
    """
    # Applied strictly in this order, each on the result of the previous one.
    ordered_renames = (
        ("input.blocks", "down_blocks"),
        ("down.blocks", "down_blocks"),
        ("middle.block", "mid_block"),
        ("mid.block", "mid_block"),
        ("output.blocks", "up_blocks"),
        ("up.blocks", "up_blocks"),
        ("transformer.blocks", "transformer_blocks"),
        ("to.q.lora", "to_q_lora"),
        ("to.k.lora", "to_k_lora"),
        ("to.v.lora", "to_v_lora"),
        ("to.out.0.lora", "to_out_lora"),
        ("proj.in", "proj_in"),
        ("proj.out", "proj_out"),
        ("emb.layers", "time_emb_proj"),
    )

    name = key.replace("lora_unet_", "").replace("_", ".")
    for old, new in ordered_renames:
        name = name.replace(old, new)

    # SDXL specific conversions.
    if "emb" in name and "time.emb.proj" not in name:
        # Drop the last ".<digits>" segment, i.e. the one with no digit after it.
        name = re.sub(r"\.\d+(?=\D*$)", "", name, count=1)
    if ".in." in name:
        name = name.replace("in.layers.2", "conv1")
    if ".out." in name:
        name = name.replace("out.layers.3", "conv2")
    if "downsamplers" in name or "upsamplers" in name:
        name = name.replace("op", "conv")
    # `str.replace` leaves the name untouched when the pattern is absent, so no guard is needed.
    name = name.replace("skip.connection", "conv_shortcut")

    # LyCORIS specific conversions.
    name = name.replace("time.emb.proj", "time_emb_proj")
    name = name.replace("conv.shortcut", "conv_shortcut")

    # General conversions: attention layers inside transformer blocks gain a `.processor` segment.
    if "transformer_blocks" in name:
        name = name.replace("attn1", "attn1.processor").replace("attn2", "attn2.processor")

    return name
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" return {new_name: alpha} ================================================ FILE: src/diffusers/loaders/lora_pipeline.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from typing import Callable, Dict, List, Optional, Union import torch from huggingface_hub.utils import validate_hf_hub_args from ..utils import ( USE_PEFT_BACKEND, convert_state_dict_to_diffusers, convert_state_dict_to_peft, convert_unet_state_dict_to_peft, deprecate, get_adapter_name, get_peft_kwargs, is_peft_version, is_transformers_available, logging, scale_lora_layers, ) from .lora_base import LoraBaseMixin from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers if is_transformers_available(): from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules logger = logging.get_logger(__name__) TEXT_ENCODER_NAME = "text_encoder" UNET_NAME = "unet" TRANSFORMER_NAME = "transformer" LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" class StableDiffusionLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). 
""" _lora_loadable_modules = ["unet", "text_encoder"] unet_name = UNET_NAME text_encoder_name = TEXT_ENCODER_NAME def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into `self.unet`. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded into `self.text_encoder`. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") # if a dict is passed, copy it instead of modifying it inplace if isinstance(pretrained_model_name_or_path_or_dict, dict): pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. 
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") self.load_lora_into_unet( state_dict, network_alphas=network_alphas, unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, adapter_name=adapter_name, _pipeline=self, ) self.load_lora_into_text_encoder( state_dict, network_alphas=network_alphas, text_encoder=getattr(self, self.text_encoder_name) if not hasattr(self, "text_encoder") else self.text_encoder, lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, ) @classmethod @validate_hf_hub_args def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs, ): r""" Return state dict for lora weights and the network alphas. We support loading A1111 formatted LoRA checkpoints in a limited capacity. This function is experimental and might change in the future. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): Can be either: - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub. - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved with [`ModelMixin.save_pretrained`]. - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. 
proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. 
@classmethod
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
    """
    This will load the LoRA layers specified in `state_dict` into `unet`.

    Parameters:
        state_dict (`dict`):
            A standard state dict containing the lora layer parameters. The keys can either be indexed directly
            into the unet or prefixed with an additional `unet` which can be used to distinguish between text
            encoder lora layers.
        network_alphas (`Dict[str, float]`):
            The value of the network alpha used for stable learning and preventing underflow. This value has the
            same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
            link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
        unet (`UNet2DConditionModel`):
            The UNet model to load the LoRA layers into.
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        _pipeline (*optional*):
            Forwarded unchanged to `unet.load_attn_procs`.

    Raises:
        ValueError: If `USE_PEFT_BACKEND` is falsy.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
    # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
    # their prefixes.
    # NOTE: `all()` over an empty state dict is True, so an empty dict loads nothing.
    only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in state_dict)
    if only_text_encoder:
        return

    # Load the layers corresponding to UNet.
    logger.info(f"Loading {cls.unet_name}.")
    unet.load_attn_procs(
        state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
    )
lora_scale (`float`): How much to scale the output of the lora linear layer before it is added with the output of the regular lora layer. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] text_encoder_lora_state_dict = { k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys } if len(text_encoder_lora_state_dict) > 0: logger.info(f"Loading {prefix}.") rank = {} text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) # convert state dict text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) for name, _ in text_encoder_attn_modules(text_encoder): for module in ("out_proj", "q_proj", "k_proj", "v_proj"): rank_key = f"{name}.{module}.lora_B.weight" if rank_key not in text_encoder_lora_state_dict: continue rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] for name, _ in text_encoder_mlp_modules(text_encoder): for module in ("fc1", "fc2"): rank_key = f"{name}.{module}.lora_B.weight" if rank_key not in text_encoder_lora_state_dict: continue rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] if network_alphas is not None: 
@classmethod
def save_lora_weights(
    cls,
    save_directory: Union[str, os.PathLike],
    unet_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
    text_encoder_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
    is_main_process: bool = True,
    weight_name: Optional[str] = None,
    save_function: Optional[Callable] = None,
    safe_serialization: bool = True,
):
    r"""
    Save the LoRA parameters corresponding to the UNet and text encoder.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to save LoRA parameters to. Will be created if it doesn't exist.
        unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
            State dict of the LoRA layers corresponding to the `unet`.
        text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
            State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
            encoder LoRA state dict because it comes from 🤗 Transformers.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the process calling this is the main process or not. Useful during distributed training and you
            need to call this function on all processes. In this case, set `is_main_process=True` only on the main
            process to avoid race conditions.
        weight_name (`str`, *optional*):
            Forwarded unchanged to `cls.write_lora_layers`.
        save_function (`Callable`):
            The function to use to save the state dictionary. Useful during distributed training when you need to
            replace `torch.save` with another method. Can be configured with the environment variable
            `DIFFUSERS_SAVE_MODE`.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

    Raises:
        ValueError: If both `unet_lora_layers` and `text_encoder_lora_layers` are missing or empty.
    """
    if not (unet_lora_layers or text_encoder_lora_layers):
        raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")

    # Each component's layers are packed under that component's prefix before being merged.
    state_dict = {}
    if unet_lora_layers:
        state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
    if text_encoder_lora_layers:
        state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

    # Save the model
    cls.write_lora_layers(
        state_dict=state_dict,
        save_directory=save_directory,
        is_main_process=is_main_process,
        weight_name=weight_name,
        save_function=save_function,
        safe_serialization=safe_serialization,
    )
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into Stable Diffusion XL [`UNet2DConditionModel`],
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
    [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
    """

    _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"]
    unet_name = UNET_NAME
    text_encoder_name = TEXT_ENCODER_NAME

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet`,
        `self.text_encoder` and `self.text_encoder_2`.

        All kwargs are forwarded to `self.lora_state_dict`.

        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
        loaded.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
        loaded into `self.unet`.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
        dict is loaded into `self.text_encoder`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].

        Raises:
            ValueError: If `USE_PEFT_BACKEND` is falsy, or if some key of the loaded state dict contains neither
                `"lora"` nor `"dora_scale"`.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        # We could have accessed the unet config from `lora_state_dict()` too. We pass
        # it here explicitly to be able to tell that it's coming from an SDXL
        # pipeline.

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict,
            unet_config=self.unet.config,
            **kwargs,
        )
        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_unet(
            state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
        )

        # `self.text_encoder` is only touched when the checkpoint actually carries keys for it.
        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
                adapter_name=adapter_name,
                _pipeline=self,
            )

        # Same for `self.text_encoder_2`: it is only accessed when matching keys exist.
        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
                adapter_name=adapter_name,
                _pipeline=self,
            )

    @classmethod
    @validate_hf_hub_args
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return state dict for lora weights and the network alphas.

        We support loading A1111 formatted LoRA checkpoints in a limited capacity.

        This function is experimental and might change in the future.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.

            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`, *optional*, defaults to None):
                Name of the serialized state dict file.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # UNet and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        unet_config = kwargs.pop("unet_config", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = cls._fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        network_alphas = None
        # TODO: replace it with a method from `state_dict_utils`
        if all(
            (
                k.startswith("lora_te_")
                or k.startswith("lora_unet_")
                or k.startswith("lora_te1_")
                or k.startswith("lora_te2_")
            )
            for k in state_dict.keys()
        ):
            # Map SDXL blocks correctly.
            if unet_config is not None:
                # use unet config to remap block numbers
                state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
            state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

        return state_dict, network_alphas

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
        """
        This will load the LoRA layers specified in `state_dict` into `unet`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            unet (`UNet2DConditionModel`):
                The UNet model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())
        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
        if not only_text_encoder:
            # Load the layers corresponding to UNet.
            logger.info(f"Loading {cls.unet_name}.")
            unet.load_attn_procs(
                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
            )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
    def load_lora_into_text_encoder(
        cls,
        state_dict,
        network_alphas,
        text_encoder,
        prefix=None,
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The key should be prefixed with an
                additional `text_encoder` to distinguish between unet lora layers.
            network_alphas (`Dict[str, float]`):
                See `LoRALinearLayer` for more details.
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
            prefix (`str`):
                Expected prefix of the `text_encoder` in the `state_dict`.
            lora_scale (`float`):
                How much to scale the output of the lora linear layer before it is added with the output of the regular
                lora layer.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        from peft import LoraConfig

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())
        prefix = cls.text_encoder_name if prefix is None else prefix

        # Safe prefix to check with.
        if any(cls.text_encoder_name in key for key in keys):
            # Load the layers corresponding to text encoder and make necessary adjustments.
            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
            text_encoder_lora_state_dict = {
                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
            }

            if len(text_encoder_lora_state_dict) > 0:
                logger.info(f"Loading {prefix}.")
                rank = {}
                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

                # convert state dict
                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

                for name, _ in text_encoder_attn_modules(text_encoder):
                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
                        rank_key = f"{name}.{module}.lora_B.weight"
                        if rank_key not in text_encoder_lora_state_dict:
                            continue
                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

                for name, _ in text_encoder_mlp_modules(text_encoder):
                    for module in ("fc1", "fc2"):
                        rank_key = f"{name}.{module}.lora_B.weight"
                        if rank_key not in text_encoder_lora_state_dict:
                            continue
                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

                if network_alphas is not None:
                    alpha_keys = [
                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
                    ]
                    network_alphas = {
                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                    }

                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
                if "use_dora" in lora_config_kwargs:
                    if lora_config_kwargs["use_dora"]:
                        if is_peft_version("<", "0.9.0"):
                            raise ValueError(
                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                            )
                    else:
                        if is_peft_version("<", "0.9.0"):
                            lora_config_kwargs.pop("use_dora")
                lora_config = LoraConfig(**lora_config_kwargs)

                # adapter_name
                if adapter_name is None:
                    adapter_name = get_adapter_name(text_encoder)

                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

                # inject LoRA layers and load the state dict
                # in transformers we automatically check whether the adapter name is already in use or not
                text_encoder.load_adapter(
                    adapter_name=adapter_name,
                    adapter_state_dict=text_encoder_lora_state_dict,
                    peft_config=lora_config,
                )

                # scale LoRA layers with `lora_scale`
                scale_lora_layers(text_encoder, weight=lora_scale)

                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

                # Offload back.
                if is_model_cpu_offload:
                    _pipeline.enable_model_cpu_offload()
                elif is_sequential_cpu_offload:
                    _pipeline.enable_sequential_cpu_offload()
                # Unsafe code />

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
        text_encoder_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
        text_encoder_2_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
        is_main_process: bool = True,
        weight_name: Optional[str] = None,
        save_function: Optional[Callable] = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and the two text encoders.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `unet`.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            weight_name (`str`, *optional*):
                Forwarded unchanged to `cls.write_lora_layers`.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

        Raises:
            ValueError: If all three `*_lora_layers` arguments are missing or empty.
        """
        state_dict = {}

        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
            )

        if unet_lora_layers:
            state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))

        if text_encoder_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))

        if text_encoder_2_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def fuse_lora(
        self,
        components: List[str] = ["unet", "text_encoder", "text_encoder_2"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        This is an experimental API.

        Args:
            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
            kwargs (`dict`, *optional*):
                Accepted but not forwarded to the parent implementation.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        super().fuse_lora(
            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
        )

    def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        This is an experimental API.

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
            kwargs (`dict`, *optional*):
                Accepted but not forwarded to the parent implementation; only `components` is passed on.
        """
        super().unfuse_lora(components=components)
local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True user_agent = { "file_type": "attn_procs_weights", "framework": "pytorch", } state_dict = cls._fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, proxies=proxies, token=token, revision=revision, subfolder=subfolder, user_agent=user_agent, allow_pickle=allow_pickle, ) return state_dict def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ Load LoRA weights specified in 
`pretrained_model_name_or_path_or_dict` into `self.unet` and `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state dict is loaded into `self.transformer`. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") # if a dict is passed, copy it instead of modifying it inplace if isinstance(pretrained_model_name_or_path_or_dict, dict): pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") self.load_lora_into_transformer( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, _pipeline=self, ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." 
in k} if len(text_encoder_state_dict) > 0: self.load_lora_into_text_encoder( text_encoder_state_dict, network_alphas=None, text_encoder=self.text_encoder, prefix="text_encoder", lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, ) text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} if len(text_encoder_2_state_dict) > 0: self.load_lora_into_text_encoder( text_encoder_2_state_dict, network_alphas=None, text_encoder=self.text_encoder_2, prefix="text_encoder_2", lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, ) @classmethod def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `transformer`. Parameters: state_dict (`dict`): A standard state dict containing the lora layer parameters. The keys can either be indexed directly into the unet or prefixed with an additional `unet` which can be used to distinguish between text encoder lora layers. transformer (`SD3Transformer2DModel`): The Transformer model to load the LoRA layers into. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
""" from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict keys = list(state_dict.keys()) transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] state_dict = { k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys } if len(state_dict.keys()) > 0: # check with first key if is not in peft format first_key = next(iter(state_dict.keys())) if "lora_A" not in first_key: state_dict = convert_unet_state_dict_to_peft(state_dict) if adapter_name in getattr(transformer, "peft_config", {}): raise ValueError( f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." ) rank = {} for key, val in state_dict.items(): if "lora_B" in key: rank[key] = val.shape[1] lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): raise ValueError( "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." ) else: lora_config_kwargs.pop("use_dora") lora_config = LoraConfig(**lora_config_kwargs) # adapter_name if adapter_name is None: adapter_name = get_adapter_name(transformer) # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) if incompatible_keys is not None: # check only for unexpected keys unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) if unexpected_keys: logger.warning( f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " f" {unexpected_keys}. " ) # Offload back. 
if is_model_cpu_offload: _pipeline.enable_model_cpu_offload() elif is_sequential_cpu_offload: _pipeline.enable_sequential_cpu_offload() # Unsafe code /> @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder def load_lora_into_text_encoder( cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, adapter_name=None, _pipeline=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` Parameters: state_dict (`dict`): A standard state dict containing the lora layer parameters. The key should be prefixed with an additional `text_encoder` to distinguish between unet lora layers. network_alphas (`Dict[str, float]`): See `LoRALinearLayer` for more details. text_encoder (`CLIPTextModel`): The text encoder model to load the LoRA layers into. prefix (`str`): Expected prefix of the `text_encoder` in the `state_dict`. lora_scale (`float`): How much to scale the output of the lora linear layer before it is added with the output of the regular lora layer. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. 
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] text_encoder_lora_state_dict = { k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys } if len(text_encoder_lora_state_dict) > 0: logger.info(f"Loading {prefix}.") rank = {} text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) # convert state dict text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) for name, _ in text_encoder_attn_modules(text_encoder): for module in ("out_proj", "q_proj", "k_proj", "v_proj"): rank_key = f"{name}.{module}.lora_B.weight" if rank_key not in text_encoder_lora_state_dict: continue rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] for name, _ in text_encoder_mlp_modules(text_encoder): for module in ("fc1", "fc2"): rank_key = f"{name}.{module}.lora_B.weight" if rank_key not in text_encoder_lora_state_dict: continue rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix ] network_alphas = { k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): raise ValueError( "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 
) else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") lora_config = LoraConfig(**lora_config_kwargs) # adapter_name if adapter_name is None: adapter_name = get_adapter_name(text_encoder) is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) # inject LoRA layers and load the state dict # in transformers we automatically check whether the adapter name is already in use or not text_encoder.load_adapter( adapter_name=adapter_name, adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config, ) # scale LoRA layers with `lora_scale` scale_lora_layers(text_encoder, weight=lora_scale) text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) # Offload back. if is_model_cpu_offload: _pipeline.enable_model_cpu_offload() elif is_sequential_cpu_offload: _pipeline.enable_sequential_cpu_offload() # Unsafe code /> @classmethod def save_lora_weights( cls, save_directory: Union[str, os.PathLike], transformer_lora_layers: Dict[str, torch.nn.Module] = None, text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): Directory to save LoRA parameters to. Will be created if it doesn't exist. transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `transformer`. text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. 
text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. save_function (`Callable`): The function to use to save the state dictionary. Useful during distributed training when you need to replace `torch.save` with another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ state_dict = {} if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): raise ValueError( "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." 
) if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) if text_encoder_2_lora_layers: state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) # Save the model cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, ) def fuse_lora( self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], lora_scale: float = 1.0, safe_fusing: bool = False, adapter_names: Optional[List[str]] = None, **kwargs, ): r""" Fuses the LoRA parameters into the original parameters of the corresponding blocks. This is an experimental API. Args: components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. lora_scale (`float`, defaults to 1.0): Controls how much to influence the outputs with the LoRA parameters. safe_fusing (`bool`, defaults to `False`): Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. adapter_names (`List[str]`, *optional*): Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. 
Example: ```py from diffusers import DiffusionPipeline import torch pipeline = DiffusionPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 ).to("cuda") pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") pipeline.fuse_lora(lora_scale=0.7) ``` """ super().fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs): r""" Reverses the effect of [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). This is an experimental API. Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. unfuse_text_encoder (`bool`, defaults to `True`): Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. """ super().unfuse_lora(components=components) class FluxLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`FluxTransformer2DModel`], [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). Specific to [`StableDiffusion3Pipeline`]. """ _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME text_encoder_name = TEXT_ENCODER_NAME @classmethod @validate_hf_hub_args # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs, ): r""" Return state dict for lora weights and the network alphas. We support loading A1111 formatted LoRA checkpoints in a limited capacity. This function is experimental and might change in the future. 
Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): Can be either: - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub. - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved with [`ModelMixin.save_pretrained`]. - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. """ # Load the main state dict first which has the LoRA layers for either of # transformer and text encoder or both. 
cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True user_agent = { "file_type": "attn_procs_weights", "framework": "pytorch", } state_dict = cls._fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, local_files_only=local_files_only, cache_dir=cache_dir, force_download=force_download, proxies=proxies, token=token, revision=revision, subfolder=subfolder, user_agent=user_agent, allow_pickle=allow_pickle, ) return state_dict def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state dict is loaded into `self.transformer`. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
""" if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") # if a dict is passed, copy it instead of modifying it inplace if isinstance(pretrained_model_name_or_path_or_dict, dict): pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") self.load_lora_into_transformer( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, _pipeline=self, ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} if len(text_encoder_state_dict) > 0: self.load_lora_into_text_encoder( text_encoder_state_dict, network_alphas=None, text_encoder=self.text_encoder, prefix="text_encoder", lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `transformer`. Parameters: state_dict (`dict`): A standard state dict containing the lora layer parameters. The keys can either be indexed directly into the unet or prefixed with an additional `unet` which can be used to distinguish between text encoder lora layers. transformer (`SD3Transformer2DModel`): The Transformer model to load the LoRA layers into. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
""" from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict keys = list(state_dict.keys()) transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] state_dict = { k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys } if len(state_dict.keys()) > 0: # check with first key if is not in peft format first_key = next(iter(state_dict.keys())) if "lora_A" not in first_key: state_dict = convert_unet_state_dict_to_peft(state_dict) if adapter_name in getattr(transformer, "peft_config", {}): raise ValueError( f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." ) rank = {} for key, val in state_dict.items(): if "lora_B" in key: rank[key] = val.shape[1] lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): raise ValueError( "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." ) else: lora_config_kwargs.pop("use_dora") lora_config = LoraConfig(**lora_config_kwargs) # adapter_name if adapter_name is None: adapter_name = get_adapter_name(transformer) # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) if incompatible_keys is not None: # check only for unexpected keys unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) if unexpected_keys: logger.warning( f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " f" {unexpected_keys}. " ) # Offload back. 
if is_model_cpu_offload: _pipeline.enable_model_cpu_offload() elif is_sequential_cpu_offload: _pipeline.enable_sequential_cpu_offload() # Unsafe code /> @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder def load_lora_into_text_encoder( cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, adapter_name=None, _pipeline=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` Parameters: state_dict (`dict`): A standard state dict containing the lora layer parameters. The key should be prefixed with an additional `text_encoder` to distinguish between unet lora layers. network_alphas (`Dict[str, float]`): See `LoRALinearLayer` for more details. text_encoder (`CLIPTextModel`): The text encoder model to load the LoRA layers into. prefix (`str`): Expected prefix of the `text_encoder` in the `state_dict`. lora_scale (`float`): How much to scale the output of the lora linear layer before it is added with the output of the regular lora layer. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. 
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] text_encoder_lora_state_dict = { k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys } if len(text_encoder_lora_state_dict) > 0: logger.info(f"Loading {prefix}.") rank = {} text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) # convert state dict text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) for name, _ in text_encoder_attn_modules(text_encoder): for module in ("out_proj", "q_proj", "k_proj", "v_proj"): rank_key = f"{name}.{module}.lora_B.weight" if rank_key not in text_encoder_lora_state_dict: continue rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] for name, _ in text_encoder_mlp_modules(text_encoder): for module in ("fc1", "fc2"): rank_key = f"{name}.{module}.lora_B.weight" if rank_key not in text_encoder_lora_state_dict: continue rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix ] network_alphas = { k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): raise ValueError( "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 
) else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") lora_config = LoraConfig(**lora_config_kwargs) # adapter_name if adapter_name is None: adapter_name = get_adapter_name(text_encoder) is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) # inject LoRA layers and load the state dict # in transformers we automatically check whether the adapter name is already in use or not text_encoder.load_adapter( adapter_name=adapter_name, adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config, ) # scale LoRA layers with `lora_scale` scale_lora_layers(text_encoder, weight=lora_scale) text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) # Offload back. if is_model_cpu_offload: _pipeline.enable_model_cpu_offload() elif is_sequential_cpu_offload: _pipeline.enable_sequential_cpu_offload() # Unsafe code /> @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer def save_lora_weights( cls, save_directory: Union[str, os.PathLike], transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): Directory to save LoRA parameters to. Will be created if it doesn't exist. transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `transformer`. text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes from 🤗 Transformers. 
is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. save_function (`Callable`): The function to use to save the state dictionary. Useful during distributed training when you need to replace `torch.save` with another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ state_dict = {} if not (transformer_lora_layers or text_encoder_lora_layers): raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if text_encoder_lora_layers: state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) # Save the model cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, ) # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer def fuse_lora( self, components: List[str] = ["transformer", "text_encoder"], lora_scale: float = 1.0, safe_fusing: bool = False, adapter_names: Optional[List[str]] = None, **kwargs, ): r""" Fuses the LoRA parameters into the original parameters of the corresponding blocks. This is an experimental API. Args: components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. lora_scale (`float`, defaults to 1.0): Controls how much to influence the outputs with the LoRA parameters. 
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    """
    LoRA loading/saving helpers for pipelines that expose a `transformer` and a `text_encoder` component.
    """

    _lora_loadable_modules = ["transformer", "text_encoder"]
    transformer_name = TRANSFORMER_NAME
    text_encoder_name = TEXT_ENCODER_NAME

    @classmethod
    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. Only keys starting with
                `cls.transformer_name` are used; the `"<transformer_name>."` prefix is stripped from them.
            network_alphas (`Dict[str, float]`):
                See `LoRALinearLayer` for more details. Filtered and un-prefixed the same way as `state_dict`. May be
                `None`.
            transformer:
                The transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            _pipeline (*optional*):
                Pipeline passed to `cls._optionally_disable_offloading`; its CPU offloading is re-enabled once the
                weights are loaded.

        Raises:
            ValueError: If the PEFT backend is unavailable, if `adapter_name` is already registered on the
                transformer, or if the state dict requests DoRA while the installed `peft` is older than 0.9.0.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

        # Keep only the transformer entries and drop their prefix. Filtering directly on `startswith` avoids the
        # quadratic "key in list" membership test of building a key list first.
        prefix_with_dot = f"{cls.transformer_name}."
        state_dict = {
            k.replace(prefix_with_dot, ""): v for k, v in state_dict.items() if k.startswith(cls.transformer_name)
        }

        if network_alphas is not None:
            network_alphas = {
                k.replace(prefix_with_dot, ""): v
                for k, v in network_alphas.items()
                if k.startswith(cls.transformer_name)
            }

        if len(state_dict) > 0:
            if adapter_name in getattr(transformer, "peft_config", {}):
                raise ValueError(
                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
                )

            # The LoRA rank of each layer is read from the second dimension of its `lora_B` weight.
            rank = {key: val.shape[1] for key, val in state_dict.items() if "lora_B" in key}

            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
            if "use_dora" in lora_config_kwargs:
                if lora_config_kwargs["use_dora"]:
                    if is_peft_version("<", "0.9.0"):
                        raise ValueError(
                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                        )
                elif is_peft_version("<", "0.9.0"):
                    # Only strip the flag for old PEFT releases (presumably their `LoraConfig` does not accept
                    # the keyword). On newer releases the flag must be kept, otherwise a DoRA checkpoint would be
                    # silently loaded as a plain LoRA. This mirrors `load_lora_into_text_encoder`.
                    lora_config_kwargs.pop("use_dora")
            lora_config = LoraConfig(**lora_config_kwargs)

            # adapter_name
            if adapter_name is None:
                adapter_name = get_adapter_name(transformer)

            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
            # otherwise loading LoRA weights will lead to an error
            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)

            if incompatible_keys is not None:
                # check only for unexpected keys
                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
                if unexpected_keys:
                    logger.warning(
                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                        f" {unexpected_keys}. "
                    )

            # Offload back.
            if is_model_cpu_offload:
                _pipeline.enable_model_cpu_offload()
            elif is_sequential_cpu_offload:
                _pipeline.enable_sequential_cpu_offload()
            # Unsafe code />

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
    def load_lora_into_text_encoder(
        cls,
        state_dict,
        network_alphas,
        text_encoder,
        prefix=None,
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The key should be prefixed with an
                additional `text_encoder` to distinguish between unet lora layers.
            network_alphas (`Dict[str, float]`):
                See `LoRALinearLayer` for more details.
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
            prefix (`str`):
                Expected prefix of the `text_encoder` in the `state_dict`. Defaults to `cls.text_encoder_name`.
            lora_scale (`float`):
                How much to scale the output of the lora linear layer before it is added with the output of the regular
                lora layer.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            _pipeline (*optional*):
                Pipeline passed to `cls._optionally_disable_offloading`; its CPU offloading is re-enabled once the
                weights are loaded.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        from peft import LoraConfig

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())
        prefix = cls.text_encoder_name if prefix is None else prefix

        # Safe prefix to check with.
        if any(cls.text_encoder_name in key for key in keys):
            # Load the layers corresponding to text encoder and make necessary adjustments.
            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
            text_encoder_lora_state_dict = {
                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
            }

            if len(text_encoder_lora_state_dict) > 0:
                logger.info(f"Loading {prefix}.")
                rank = {}
                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

                # convert state dict
                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

                # Collect the rank of every attention projection that has LoRA weights.
                for name, _ in text_encoder_attn_modules(text_encoder):
                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
                        rank_key = f"{name}.{module}.lora_B.weight"
                        if rank_key not in text_encoder_lora_state_dict:
                            continue
                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

                # Same for the MLP layers.
                for name, _ in text_encoder_mlp_modules(text_encoder):
                    for module in ("fc1", "fc2"):
                        rank_key = f"{name}.{module}.lora_B.weight"
                        if rank_key not in text_encoder_lora_state_dict:
                            continue
                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

                if network_alphas is not None:
                    alpha_keys = [
                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
                    ]
                    network_alphas = {
                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                    }

                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
                if "use_dora" in lora_config_kwargs:
                    if lora_config_kwargs["use_dora"]:
                        if is_peft_version("<", "0.9.0"):
                            raise ValueError(
                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                            )
                    else:
                        if is_peft_version("<", "0.9.0"):
                            lora_config_kwargs.pop("use_dora")
                lora_config = LoraConfig(**lora_config_kwargs)

                # adapter_name
                if adapter_name is None:
                    adapter_name = get_adapter_name(text_encoder)

                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

                # inject LoRA layers and load the state dict
                # in transformers we automatically check whether the adapter name is already in use or not
                text_encoder.load_adapter(
                    adapter_name=adapter_name,
                    adapter_state_dict=text_encoder_lora_state_dict,
                    peft_config=lora_config,
                )

                # scale LoRA layers with `lora_scale`
                scale_lora_layers(text_encoder, weight=lora_scale)

                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

                # Offload back.
                if is_model_cpu_offload:
                    _pipeline.enable_model_cpu_offload()
                elif is_sequential_cpu_offload:
                    _pipeline.enable_sequential_cpu_offload()
                # Unsafe code />

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
        transformer_lora_layers: Dict[str, torch.nn.Module] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the transformer and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            weight_name (`str`, *optional*):
                Forwarded unchanged to `cls.write_lora_layers`.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

        Raises:
            ValueError: If both `transformer_lora_layers` and `text_encoder_lora_layers` are empty or `None`.
        """
        state_dict = {}

        if not (transformer_lora_layers or text_encoder_lora_layers):
            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        if text_encoder_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )
class PeftAdapterMixin:
    """
    A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
    more details about adapters and injecting them in a base model, check out the PEFT
    [documentation](https://huggingface.co/docs/peft/index).

    Install the latest version of PEFT, and use this mixin to:

    - Attach new adapters in the model.
    - Attach multiple adapters and iteratively activate/deactivate them.
    - Activate/deactivate all adapters from the model.
    - Get a list of the active adapters.
    """

    # Class-level default. `add_adapter()` sets it to `True` on the instance and `unload_lora()` resets it.
    _hf_peft_config_loaded = False

    def set_adapters(
        self,
        adapter_names: Union[List[str], str],
        weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
    ):
        """
        Set the currently active adapters for use in the UNet.

        Args:
            adapter_names (`List[str]` or `str`):
                The names of the adapters to use.
            weights (`Union[float, Dict, List[float], List[Dict], List[None]]`, *optional*):
                The adapter(s) weights to use with the UNet. A non-list value is applied to every adapter. If `None`
                (or a `None` entry in the list), the weight is set to `1.0` for that adapter.

        Raises:
            ValueError: If the PEFT backend is unavailable or the number of weights does not match the number of
                adapter names.
            KeyError: If the model class has no entry in `_SET_ADAPTER_SCALE_FN_MAPPING`.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `set_adapters()`.")

        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

        # Expand weights into a list, one entry per adapter
        # examples for e.g. 2 adapters:  7 -> [7, 7] ; None -> [None, None]
        if not isinstance(weights, list):
            weights = [weights] * len(adapter_names)

        if len(adapter_names) != len(weights):
            raise ValueError(
                f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
            )

        # Set None values to default of 1.0
        # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
        weights = [w if w is not None else 1.0 for w in weights]

        # e.g. [{...}, 7] -> [{expanded dict...}, 7]
        scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__]
        weights = scale_expansion_fn(self, weights)

        set_weights_and_activate_adapters(self, adapter_names, weights)

    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
        r"""
        Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
        to the adapter to follow the convention of the PEFT library.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Args:
            adapter_config (`[~peft.PeftConfig]`):
                The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
                methods.
            adapter_name (`str`, *optional*, defaults to `"default"`):
                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.

        Raises:
            ImportError: If PEFT is not installed.
            ValueError: If `adapter_config` is not a `PeftConfig` or `adapter_name` is already in use.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        from peft import PeftConfig, inject_adapter_in_model

        # Validate the config *before* touching `_hf_peft_config_loaded`: flipping the flag first would leave the
        # model claiming an adapter is loaded although nothing was injected when the config is rejected.
        if not isinstance(adapter_config, PeftConfig):
            raise ValueError(
                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
            )

        if not self._hf_peft_config_loaded:
            self._hf_peft_config_loaded = True
        elif adapter_name in getattr(self, "peft_config", {}):
            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

        # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
        # handled by the `load_lora_layers` or `StableDiffusionLoraLoaderMixin`. Therefore we set it to `None` here.
        adapter_config.base_model_name_or_path = None
        inject_adapter_in_model(adapter_config, self, adapter_name)
        self.set_adapter(adapter_name)

    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
        """
        Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).

        Args:
            adapter_name (Union[str, List[str]])):
                The list of adapters to set or the adapter name in the case of a single adapter.

        Raises:
            ValueError: If no adapter has been loaded, if a requested adapter is unknown, if several adapters are
                requested with a PEFT version lacking multi-adapter support, or if the model has no tuner layer.
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        if isinstance(adapter_name, str):
            adapter_name = [adapter_name]

        missing = set(adapter_name) - set(self.peft_config)
        if len(missing) > 0:
            raise ValueError(
                f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
                f" current loaded adapters are: {list(self.peft_config.keys())}"
            )

        from peft.tuners.tuners_utils import BaseTunerLayer

        _adapters_has_been_set = False

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                if hasattr(module, "set_adapter"):
                    module.set_adapter(adapter_name)
                # Previous versions of PEFT does not support multi-adapter inference
                elif len(adapter_name) != 1:
                    raise ValueError(
                        "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
                        " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
                    )
                else:
                    module.active_adapter = adapter_name
                _adapters_has_been_set = True

        if not _adapters_has_been_set:
            raise ValueError(
                "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
            )

    def disable_adapters(self) -> None:
        r"""
        Disable all adapters attached to the model and fallback to inference with the base model only.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=False)
                else:
                    # support for older PEFT versions
                    module.disable_adapters = True

    def enable_adapters(self) -> None:
        """
        Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of
        adapters to enable.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=True)
                else:
                    # support for older PEFT versions
                    module.disable_adapters = False

    def active_adapters(self) -> List[str]:
        """
        Gets the current list of active adapters of the model.

        The value is read from the first tuner layer found in the model. A single adapter name stored as a plain
        string is wrapped in a list so that the return type is always a list; an empty list is returned when the
        model contains no tuner layer.

        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        [documentation](https://huggingface.co/docs/peft).
        """
        check_peft_version(min_version=MIN_PEFT_VERSION)

        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                active = module.active_adapter
                # `active_adapter` may be a bare string (presumably with PEFT versions predating multi-adapter
                # support); normalize so callers always receive a list.
                return [active] if isinstance(active, str) else active

        return []

    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
        """
        Merge the LoRA weights of every tuner layer into its base layer.

        Args:
            lora_scale (`float`, defaults to 1.0):
                Scale applied to each tuner layer (via `scale_layer`) before merging when different from `1.0`.
            safe_fusing (`bool`, defaults to `False`):
                Forwarded to the layers' `merge` method as `safe_merge`.
            adapter_names (`List[str]`, *optional*):
                Adapter names forwarded to `merge` when the installed PEFT version supports the argument.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `fuse_lora()`.")

        # Stored on the instance because `_fuse_lora_apply` reads them for every visited module.
        self.lora_scale = lora_scale
        self._safe_fusing = safe_fusing
        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))

    def _fuse_lora_apply(self, module, adapter_names=None):
        """Scale and merge `module` if it is a PEFT tuner layer; other modules are left untouched."""
        from peft.tuners.tuners_utils import BaseTunerLayer

        merge_kwargs = {"safe_merge": self._safe_fusing}

        if isinstance(module, BaseTunerLayer):
            if self.lora_scale != 1.0:
                module.scale_layer(self.lora_scale)

            # For BC with previous PEFT versions, we need to check the signature
            # of the `merge` method to see if it supports the `adapter_names` argument.
            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
            if "adapter_names" in supported_merge_kwargs:
                merge_kwargs["adapter_names"] = adapter_names
            elif adapter_names is not None:
                raise ValueError(
                    "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
                    " to the latest version of PEFT. `pip install -U peft`"
                )

            module.merge(**merge_kwargs)

    def unfuse_lora(self):
        """Reverse `fuse_lora()` by calling `unmerge()` on every tuner layer of the model."""
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `unfuse_lora()`.")
        self.apply(self._unfuse_lora_apply)

    def _unfuse_lora_apply(self, module):
        """Unmerge `module` if it is a PEFT tuner layer; other modules are left untouched."""
        from peft.tuners.tuners_utils import BaseTunerLayer

        if isinstance(module, BaseTunerLayer):
            module.unmerge()

    def unload_lora(self):
        """
        Remove all PEFT layers from the model and forget the adapter configuration.

        Besides deleting `peft_config`, `_hf_peft_config_loaded` is reset so that a later `add_adapter()` starts
        from a clean state instead of looking up the deleted `peft_config`.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `unload_lora()`.")

        from ..utils import recurse_remove_peft_layers

        recurse_remove_peft_layers(self)
        if hasattr(self, "peft_config"):
            del self.peft_config
        self._hf_peft_config_loaded = False

    def disable_lora(self):
        """
        Disables the active LoRA layers of the underlying model.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.disable_lora()
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
        set_adapter_layers(self, enabled=False)

    def enable_lora(self):
        """
        Enables the active LoRA layers of the underlying model.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.enable_lora()
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
        set_adapter_layers(self, enabled=True)

    def delete_adapters(self, adapter_names: Union[List[str], str]):
        """
        Delete an adapter's LoRA layers from the underlying model.

        Args:
            adapter_names (`Union[List[str], str]`):
                The names (single string or list of strings) of the adapter to delete.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        pipeline.delete_adapters("cinematic")
        ```
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        if isinstance(adapter_names, str):
            adapter_names = [adapter_names]

        for adapter_name in adapter_names:
            delete_adapter_layers(self, adapter_name)

            # Pop also the corresponding adapter from the config
            if hasattr(self, "peft_config"):
                self.peft_config.pop(adapter_name, None)
def load_single_file_sub_model(
    library_name,
    class_name,
    name,
    checkpoint,
    pipelines,
    is_pipeline_module,
    cached_model_config_path,
    original_config=None,
    local_files_only=False,
    torch_dtype=None,
    is_legacy_loading=False,
    **kwargs,
):
    """
    Instantiate one pipeline component (`name`) while loading a pipeline from a single-file checkpoint.

    The component class is resolved from `pipelines.<library_name>` when `is_pipeline_module` is set, otherwise
    from the imported `library_name` module. The first matching strategy below is used:

    1. classes deriving from `FromOriginalModelMixin` -> `class_obj.from_single_file`
    2. transformers models for which `is_clip_model_in_single_file` holds -> `create_diffusers_clip_model_from_ldm`
    3. transformers models for which `is_t5_in_single_file` holds -> `create_diffusers_t5_model_from_checkpoint`
    4. tokenizers, legacy loading only -> `_legacy_load_clip_tokenizer`
    5. schedulers, legacy loading only -> `_legacy_load_scheduler`
    6. anything else -> `class_obj.from_pretrained` pointed at `cached_model_config_path`

    Raises:
        ValueError: If the fallback is reached and the class has no `from_pretrained` method.
        SingleFileComponentError: If the fallback is reached for a diffusers/transformers model whose weights are
            not found by `_is_model_weights_in_cached_folder`.
    """
    # Resolve the class of the component to build.
    if is_pipeline_module:
        source_module = getattr(pipelines, library_name)
    else:
        # else we just import it from the library.
        source_module = importlib.import_module(library_name)
    class_obj = getattr(source_module, class_name)

    transformers_version = (
        version.parse(version.parse(transformers.__version__).base_version)
        if is_transformers_available()
        else "N/A"
    )

    # The `is_transformers_available()` guard must stay first: the transformers base classes are only imported
    # when the library is installed.
    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )
    is_tokenizer = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedTokenizer)
        and transformers_version >= version.parse("4.20.0")
    )

    diffusers_module = importlib.import_module(__name__.split(".")[0])
    is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
    is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)

    if is_diffusers_single_file_model:
        single_file_loader = getattr(class_obj, "from_single_file")

        # We cannot provide two different config options to the `from_single_file` method
        # Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
        config_source = None if original_config else cached_model_config_path

        return single_file_loader(
            pretrained_model_link_or_path_or_dict=checkpoint,
            original_config=original_config,
            config=config_source,
            subfolder=name,
            torch_dtype=torch_dtype,
            local_files_only=local_files_only,
            **kwargs,
        )

    if is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
        return create_diffusers_clip_model_from_ldm(
            class_obj,
            checkpoint=checkpoint,
            config=cached_model_config_path,
            subfolder=name,
            torch_dtype=torch_dtype,
            local_files_only=local_files_only,
            is_legacy_loading=is_legacy_loading,
        )

    if is_transformers_model and is_t5_in_single_file(checkpoint):
        return create_diffusers_t5_model_from_checkpoint(
            class_obj,
            checkpoint=checkpoint,
            config=cached_model_config_path,
            subfolder=name,
            torch_dtype=torch_dtype,
            local_files_only=local_files_only,
        )

    if is_tokenizer and is_legacy_loading:
        return _legacy_load_clip_tokenizer(
            class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
        )

    if is_diffusers_scheduler and is_legacy_loading:
        return _legacy_load_scheduler(
            class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
        )

    # Fallback: regular `from_pretrained` loading from the cached config folder.
    if not hasattr(class_obj, "from_pretrained"):
        raise ValueError(
            (
                f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
                " a supported loading method."
            )
        )

    loading_kwargs = {
        "pretrained_model_name_or_path": cached_model_config_path,
        "subfolder": name,
        "local_files_only": local_files_only,
    }

    # Schedulers and Tokenizers don't make use of torch_dtype
    # Skip passing it to those objects
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype

    if is_diffusers_model or is_transformers_model:
        weights_present = _is_model_weights_in_cached_folder(cached_model_config_path, name)
        if not weights_present:
            raise SingleFileComponentError(
                f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
            )

    pretrained_loader = getattr(class_obj, "from_pretrained")
    return pretrained_loader(**loading_kwargs)
def _infer_pipeline_config_dict(pipeline_class):
    """
    Infer a pipeline config dict from the required `__init__` parameters of `pipeline_class`.

    Args:
        pipeline_class:
            Pipeline class exposing `_get_signature_types()`.

    Returns:
        `dict`: The result of `_map_component_types_to_config_dict` applied to the signature types of every
        parameter that has no default value.
    """
    init_parameters = inspect.signature(pipeline_class.__init__).parameters
    # Identity check against the public sentinel: `==` would call `__eq__` on arbitrary default values,
    # which can raise or return a non-boolean (e.g. for array-like defaults).
    required_parameters = {
        name for name, parameter in init_parameters.items() if parameter.default is inspect.Parameter.empty
    }

    # Ignore parameters that are not required for the pipeline
    component_types = {
        name: types
        for name, types in pipeline_class._get_signature_types().items()
        if name in required_parameters
    }

    return _map_component_types_to_config_dict(component_types)


def _download_diffusers_model_config_from_hub(
    pretrained_model_name_or_path,
    cache_dir,
    revision,
    proxies,
    force_download=None,
    local_files_only=None,
    token=None,
):
    """
    Download only the config-like files of a Diffusers repo and return the local snapshot path.

    The download is restricted to `*.json`, `*.txt` and `*.model` files (at any depth); all other arguments
    are forwarded unchanged to `snapshot_download`.

    Returns:
        The value returned by `snapshot_download` (the path of the cached snapshot).
    """
    allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
    cached_model_path = snapshot_download(
        pretrained_model_name_or_path,
        cache_dir=cache_dir,
        revision=revision,
        proxies=proxies,
        force_download=force_download,
        local_files_only=local_files_only,
        token=token,
        allow_patterns=allow_patterns,
    )

    return cached_model_path
class FromSingleFileMixin:
    """
    Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
    """

    @classmethod
    @validate_hf_hub_args
    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
        r"""
        Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
        format. The pipeline is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:
                    - A link to the `.ckpt` file (for example
                      `"https://huggingface.co/REPO_ID/blob/main/FILENAME.ckpt"`) on the Hub.
                    - A path to a *file* containing all pipeline weights.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            original_config (`str`, *optional*):
                The original config that was used to train the model. It is resolved with `fetch_original_config`.
                When it is provided and no Diffusers config is available locally (with `local_files_only=True`),
                the legacy loading path is used.
            original_config_file (`str`, *optional*):
                Deprecated alias of `original_config`.
            config (`str`, *optional*):
                Can be either:
                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
                      hosted on the Hub.
                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
                      component configs in Diffusers format.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
                class). The overwritten components are passed directly to the pipelines `__init__` method. See example
                below for more information.

        Raises:
            ValueError: If `config` is neither a local directory nor a valid repo id, or if required pipeline
                components are missing after loading.
            SingleFileComponentError: If a component could not be loaded from the checkpoint.

        Examples:

        ```py
        >>> from diffusers import StableDiffusionPipeline

        >>> # Download pipeline from huggingface.co and cache.
        >>> pipeline = StableDiffusionPipeline.from_single_file(
        ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
        ... )

        >>> # Download pipeline from local file
        >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
        >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")

        >>> # Enable float16 and move to GPU
        >>> pipeline = StableDiffusionPipeline.from_single_file(
        ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
        ...     torch_dtype=torch.float16,
        ... )
        >>> pipeline.to("cuda")
        ```

        """
        original_config_file = kwargs.pop("original_config_file", None)
        config = kwargs.pop("config", None)
        original_config = kwargs.pop("original_config", None)

        if original_config_file is not None:
            deprecation_message = (
                "`original_config_file` argument is deprecated and will be removed in future versions. "
                "Please use the `original_config` argument instead."
            )
            deprecate("original_config_file", "1.0.0", deprecation_message)
            original_config = original_config_file

        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        cache_dir = kwargs.pop("cache_dir", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)

        is_legacy_loading = False

        # We shouldn't allow configuring individual models components through a Pipeline creation method
        # These model kwargs should be deprecated
        # NOTE: `get` (not `pop`) is deliberate here: the value stays in `kwargs` and is forwarded below.
        scaling_factor = kwargs.get("scaling_factor", None)
        if scaling_factor is not None:
            deprecation_message = (
                "Passing the `scaling_factor` argument to `from_single_file` is deprecated "
                "and will be ignored in future versions."
            )
            deprecate("scaling_factor", "1.0.0", deprecation_message)

        if original_config is not None:
            original_config = fetch_original_config(original_config, local_files_only=local_files_only)

        from ..pipelines.pipeline_utils import _get_pipeline_class

        pipeline_class = _get_pipeline_class(cls, config=None)

        checkpoint = load_single_file_checkpoint(
            pretrained_model_link_or_path,
            force_download=force_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
        )

        if config is None:
            config = fetch_diffusers_config(checkpoint)
            default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
        else:
            default_pretrained_model_config_name = config

        if not os.path.isdir(default_pretrained_model_config_name):
            # Provided config is a repo_id
            if default_pretrained_model_config_name.count("/") > 1:
                raise ValueError(
                    f'The provided config "{config}"'
                    " is neither a valid local path nor a valid repo id. Please check the parameter."
                )
            try:
                # Attempt to download the config files for the pipeline
                cached_model_config_path = _download_diffusers_model_config_from_hub(
                    default_pretrained_model_config_name,
                    cache_dir=cache_dir,
                    revision=revision,
                    proxies=proxies,
                    force_download=force_download,
                    local_files_only=local_files_only,
                    token=token,
                )
                config_dict = pipeline_class.load_config(cached_model_config_path)

            except LocalEntryNotFoundError:
                # `local_files_only=True` but a local diffusers format model config is not available in the cache
                # If `original_config` is not provided, we need override `local_files_only` to False
                # to fetch the config files from the hub so that we have a way
                # to configure the pipeline components.
                if original_config is None:
                    logger.warning(
                        "`local_files_only` is True but no local configs were found for this checkpoint.\n"
                        "Attempting to download the necessary config files for this pipeline.\n"
                    )
                    cached_model_config_path = _download_diffusers_model_config_from_hub(
                        default_pretrained_model_config_name,
                        cache_dir=cache_dir,
                        revision=revision,
                        proxies=proxies,
                        force_download=force_download,
                        local_files_only=False,
                        token=token,
                    )
                    config_dict = pipeline_class.load_config(cached_model_config_path)

                else:
                    # For backwards compatibility
                    # If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
                    logger.warning(
                        "Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
                        "This may lead to errors if the model components are not correctly inferred. \n"
                        "To avoid this warning, please explicitly pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
                        "e.g. `from_single_file('path/to/checkpoint', config='path/to/local/diffusers/repo')` \n"
                        "or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
                        "the necessary config files.\n"
                    )
                    is_legacy_loading = True
                    cached_model_config_path = None

                    config_dict = _infer_pipeline_config_dict(pipeline_class)
                    config_dict["_class_name"] = pipeline_class.__name__

        else:
            # Provided config is a path to a local directory attempt to load directly.
            cached_model_config_path = default_pretrained_model_config_name
            config_dict = pipeline_class.load_config(cached_model_config_path)

        # pop out "_ignore_files" as it is only needed for download
        config_dict.pop("_ignore_files", None)

        expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
        # Components / optional arguments supplied by the caller take precedence over anything loaded below.
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}

        init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
        init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}

        from diffusers import pipelines

        # remove `null` components
        def load_module(name, value):
            if value[0] is None:
                return False
            if name in passed_class_obj and passed_class_obj[name] is None:
                return False
            if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
                return False

            return True

        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

        for name, (library_name, class_name) in logging.tqdm(
            sorted(init_dict.items()), desc="Loading pipeline components..."
        ):
            loaded_sub_model = None
            is_pipeline_module = hasattr(pipelines, library_name)

            if name in passed_class_obj:
                loaded_sub_model = passed_class_obj[name]

            else:
                try:
                    loaded_sub_model = load_single_file_sub_model(
                        library_name=library_name,
                        class_name=class_name,
                        name=name,
                        checkpoint=checkpoint,
                        is_pipeline_module=is_pipeline_module,
                        cached_model_config_path=cached_model_config_path,
                        pipelines=pipelines,
                        torch_dtype=torch_dtype,
                        original_config=original_config,
                        local_files_only=local_files_only,
                        is_legacy_loading=is_legacy_loading,
                        **kwargs,
                    )
                except SingleFileComponentError as e:
                    # Re-raise with usage guidance, keeping the original error as the cause.
                    raise SingleFileComponentError(
                        (
                            f"{e.message}\n"
                            f"Please load the component before passing it in as an argument to `from_single_file`.\n"
                            f"\n"
                            f"{name} = {class_name}.from_pretrained('...')\n"
                            f"pipe = {pipeline_class.__name__}.from_single_file('path/to/checkpoint', {name}={name})\n"
                            f"\n"
                        )
                    ) from e

            init_kwargs[name] = loaded_sub_model

        missing_modules = set(expected_modules) - set(init_kwargs.keys())
        passed_modules = list(passed_class_obj.keys())
        optional_modules = pipeline_class._optional_components

        if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
            # Everything still missing was either passed explicitly (possibly as `None`) or is optional.
            for module in missing_modules:
                init_kwargs[module] = passed_class_obj.get(module, None)
        elif len(missing_modules) > 0:
            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
            raise ValueError(
                f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
            )

        # deprecated kwargs
        load_safety_checker = kwargs.pop("load_safety_checker", None)
        if load_safety_checker is not None:
            deprecation_message = (
                "Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor` "
                "using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
            )
            deprecate("load_safety_checker", "1.0.0", deprecation_message)

            safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
            init_kwargs.update(safety_checker_components)

        pipe = pipeline_class(**init_kwargs)

        return pipe
# Registry of the model classes that `FromOriginalModelMixin.from_single_file` can load, keyed by class name.
# Keys of each entry, as consumed by `from_single_file`:
#   - "checkpoint_mapping_fn": converts the original checkpoint into a Diffusers-format state dict
#     (looked up unconditionally, so every entry must define it).
#   - "config_mapping_fn": converts an `original_config` into a Diffusers config; only needed when the caller
#     passes `original_config` (a ValueError is raised if it is missing in that case).
#   - "default_subfolder": subfolder used to load the model config when no `config` argument is passed.
#   - "legacy_kwargs": maps legacy keyword argument names to their current names.
# `_get_single_file_loadable_mapping_class` returns the first entry (in insertion order) that the class is a
# subclass of.
SINGLE_FILE_LOADABLE_CLASSES = {
    "StableCascadeUNet": {
        "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
    },
    "UNet2DConditionModel": {
        "checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
        "config_mapping_fn": create_unet_diffusers_config_from_ldm,
        "default_subfolder": "unet",
        "legacy_kwargs": {
            "num_in_channels": "in_channels",  # Legacy kwargs supported by `from_single_file` mapped to new args
        },
    },
    "AutoencoderKL": {
        "checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
        "config_mapping_fn": create_vae_diffusers_config_from_ldm,
        "default_subfolder": "vae",
    },
    "ControlNetModel": {
        "checkpoint_mapping_fn": convert_controlnet_checkpoint,
        "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
    },
    "SD3Transformer2DModel": {
        "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "MotionAdapter": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "SparseControlNetModel": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "FluxTransformer2DModel": {
        "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
}
def _get_single_file_loadable_mapping_class(cls):
    """
    Return the name of the first entry of `SINGLE_FILE_LOADABLE_CLASSES` that `cls` derives from.

    Entries are tried in the dictionary's iteration order; `None` is returned when no entry matches.
    """
    package = importlib.import_module(__name__.split(".")[0])

    return next(
        (
            candidate_name
            for candidate_name in SINGLE_FILE_LOADABLE_CLASSES
            if issubclass(cls, getattr(package, candidate_name))
        ),
        None,
    )


def _get_mapping_function_kwargs(mapping_fn, **kwargs):
    """
    Keep only the entries of `kwargs` that `mapping_fn` declares as parameters.

    The returned dict follows the order of the parameters in the signature of `mapping_fn`.
    """
    accepted = inspect.signature(mapping_fn).parameters

    return {param_name: kwargs[param_name] for param_name in accepted if param_name in kwargs}
class FromOriginalModelMixin:
    """
    Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
    """

    @classmethod
    @validate_hf_hub_args
    def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
        r"""
        Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
        is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pretrained_model_link_or_path_or_dict (`str`, *optional*):
                Can be either:
                    - A link to the `.safetensors` or `.ckpt` file (for example
                      `"https://huggingface.co/REPO_ID/blob/main/FILENAME.safetensors"`) on the Hub.
                    - A path to a local *file* containing the weights of the component model.
                    - A state dict containing the component model weights.
            config (`str`, *optional*):
                - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
                  on the Hub.
                - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
                  configs in Diffusers format.
                Cannot be combined with `original_config`.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            original_config (`str`, *optional*):
                Dict or path to a yaml file containing the configuration for the model in its original format. If a
                dict is provided, it will be used to initialize the model configuration. Cannot be combined with
                `config`.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to True, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Keyword arguments matching the model's `__init__` signature override the corresponding entries of
                the model config. Keyword arguments matching parameters of the config / checkpoint mapping
                functions are forwarded to them. Legacy argument names listed under `legacy_kwargs` in
                `SINGLE_FILE_LOADABLE_CLASSES` are renamed first.

        Raises:
            ValueError: If the class is not single-file loadable, if both `config` and `original_config` are
                passed, if `original_config` is passed for a class without a config mapping function, or if
                `config` is not a string.
            SingleFileComponentError: If the converted checkpoint is empty.

        ```py
        >>> from diffusers import StableCascadeUNet

        >>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
        >>> model = StableCascadeUNet.from_single_file(ckpt_path)
        ```
        """

        mapping_class_name = _get_single_file_loadable_mapping_class(cls)
        # if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
        if mapping_class_name is None:
            raise ValueError(
                f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
            )

        pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
        if pretrained_model_link_or_path is not None:
            deprecation_message = (
                "Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
            )
            deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
            pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path

        config = kwargs.pop("config", None)
        original_config = kwargs.pop("original_config", None)

        if config is not None and original_config is not None:
            raise ValueError(
                "`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
            )

        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        cache_dir = kwargs.pop("cache_dir", None)
        local_files_only = kwargs.pop("local_files_only", None)
        subfolder = kwargs.pop("subfolder", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)

        # A dict is taken to be an already loaded state dict; anything else is resolved to a checkpoint file.
        if isinstance(pretrained_model_link_or_path_or_dict, dict):
            checkpoint = pretrained_model_link_or_path_or_dict
        else:
            checkpoint = load_single_file_checkpoint(
                pretrained_model_link_or_path_or_dict,
                force_download=force_download,
                proxies=proxies,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                revision=revision,
            )

        mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]

        checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
        if original_config:
            if "config_mapping_fn" in mapping_functions:
                config_mapping_fn = mapping_functions["config_mapping_fn"]
            else:
                config_mapping_fn = None

            if config_mapping_fn is None:
                raise ValueError(
                    (
                        f"`original_config` has been provided for {mapping_class_name} but no mapping function "
                        "was found to convert the original config to a Diffusers config in "
                        "`diffusers.loaders.single_file_utils`"
                    )
                )

            if isinstance(original_config, str):
                # If original_config is a URL or filepath fetch the original_config dict
                original_config = fetch_original_config(original_config, local_files_only=local_files_only)

            config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
            diffusers_model_config = config_mapping_fn(
                original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
            )
        else:
            if config:
                if isinstance(config, str):
                    default_pretrained_model_config_name = config
                else:
                    raise ValueError(
                        (
                            "Invalid `config` argument. Please provide a string representing a repo id "
                            "or path to a local Diffusers model repo."
                        )
                    )

            else:
                config = fetch_diffusers_config(checkpoint)
                default_pretrained_model_config_name = config["pretrained_model_name_or_path"]

                # The subfolder defaults only apply to the inferred config (a dict); a user-supplied `config`
                # is a string and is used together with the `subfolder` argument as passed.
                if "default_subfolder" in mapping_functions:
                    subfolder = mapping_functions["default_subfolder"]

                subfolder = subfolder or config.pop(
                    "subfolder", None
                )  # some configs contain a subfolder key, e.g. StableCascadeUNet

            diffusers_model_config = cls.load_config(
                pretrained_model_name_or_path=default_pretrained_model_config_name,
                subfolder=subfolder,
                local_files_only=local_files_only,
            )

        expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)

        # Map legacy kwargs to new kwargs
        if "legacy_kwargs" in mapping_functions:
            legacy_kwargs = mapping_functions["legacy_kwargs"]
            for legacy_key, new_key in legacy_kwargs.items():
                if legacy_key in kwargs:
                    kwargs[new_key] = kwargs.pop(legacy_key)

        # Only kwargs that are part of the model signature may override the config.
        model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
        diffusers_model_config.update(model_kwargs)

        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
        diffusers_format_checkpoint = checkpoint_mapping_fn(
            config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
        )
        if not diffusers_format_checkpoint:
            raise SingleFileComponentError(
                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
            )

        # With accelerate available the model is created under `init_empty_weights` and the weights are
        # loaded with `load_model_dict_into_meta`; otherwise a regular `load_state_dict` is used.
        ctx = init_empty_weights if is_accelerate_available() else nullcontext
        with ctx():
            model = cls.from_config(diffusers_model_config)

        if is_accelerate_available():
            unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)

        else:
            _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

        if model._keys_to_ignore_on_load_unexpected is not None:
            for pat in model._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
            )

        if torch_dtype is not None:
            model.to(torch_dtype)

        model.eval()

        return model
# Characteristic state-dict keys, one per supported checkpoint family.
# NOTE(review): presumably the presence of a key in a checkpoint is used to identify the model type
# (the consumer is not part of this section) - confirm against `fetch_diffusers_config`.
CHECKPOINT_KEY_NAMES = {
    "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
    "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
    "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
    "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
    "controlnet": "control_model.time_embed.0.weight",
    "playground-v2-5": "edm_mean",
    "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
    "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
    "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
    "clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",
    "open_clip": "cond_stage_model.model.token_embedding.weight",
    "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
    "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
    "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
    "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
    "stable_cascade_stage_c": "clip_txt_mapper.weight",
    "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
    "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
    "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
    "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
    "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
    "animatediff_rgb": "controlnet_cond_embedding.weight",
    "flux": "double_blocks.0.img_attn.norm.key_norm.scale",
}

# Default Diffusers repo (and optional subfolder) holding the configs for each checkpoint family.
# Each value has a "pretrained_model_name_or_path" key; some also carry a "subfolder" key.
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"},
    "xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"},
    "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
    "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
    "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
    "inpainting": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-inpainting"},
    "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
    "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
    "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
    "v1": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5"},
    "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
    "stable_cascade_stage_b_lite": {
        "pretrained_model_name_or_path": "stabilityai/stable-cascade",
        "subfolder": "decoder_lite",
    },
    "stable_cascade_stage_c": {
        "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
        "subfolder": "prior",
    },
    "stable_cascade_stage_c_lite": {
        "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
        "subfolder": "prior_lite",
    },
    "sd3": {
        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
    },
    "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
    "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
    "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
    "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
    "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
    "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
    "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
    "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
}

# Use to configure model sample size when original config is provided
DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
    "xl_base": 1024,
    "xl_refiner": 1024,
    "xl_inpaint": 1024,
    "playground-v2-5": 1024,
    "upscale": 512,
    "inpainting": 512,
    "inpainting_v2": 512,
    "controlnet": 512,
    "v2": 768,
    "v1": 512,
}

# Parameter-name translation tables, grouped per component.
# Keys are Diffusers names and values are the corresponding LDM / original checkpoint names.
DIFFUSERS_TO_LDM_MAPPING = {
    "unet": {
        "layers": {
            "time_embedding.linear_1.weight": "time_embed.0.weight",
            "time_embedding.linear_1.bias": "time_embed.0.bias",
            "time_embedding.linear_2.weight": "time_embed.2.weight",
            "time_embedding.linear_2.bias": "time_embed.2.bias",
            "conv_in.weight": "input_blocks.0.0.weight",
            "conv_in.bias": "input_blocks.0.0.bias",
            "conv_norm_out.weight": "out.0.weight",
            "conv_norm_out.bias": "out.0.bias",
            "conv_out.weight": "out.2.weight",
            "conv_out.bias": "out.2.bias",
        },
        "class_embed_type": {
            "class_embedding.linear_1.weight": "label_emb.0.0.weight",
            "class_embedding.linear_1.bias": "label_emb.0.0.bias",
            "class_embedding.linear_2.weight": "label_emb.0.2.weight",
            "class_embedding.linear_2.bias": "label_emb.0.2.bias",
        },
        "addition_embed_type": {
            "add_embedding.linear_1.weight": "label_emb.0.0.weight",
            "add_embedding.linear_1.bias": "label_emb.0.0.bias",
            "add_embedding.linear_2.weight": "label_emb.0.2.weight",
            "add_embedding.linear_2.bias": "label_emb.0.2.bias",
        },
    },
    "controlnet": {
        "layers": {
            "time_embedding.linear_1.weight": "time_embed.0.weight",
            "time_embedding.linear_1.bias": "time_embed.0.bias",
            "time_embedding.linear_2.weight": "time_embed.2.weight",
            "time_embedding.linear_2.bias": "time_embed.2.bias",
            "conv_in.weight": "input_blocks.0.0.weight",
            "conv_in.bias": "input_blocks.0.0.bias",
            "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
            "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias",
            "controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight",
            "controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias",
        },
        "class_embed_type": {
            "class_embedding.linear_1.weight": "label_emb.0.0.weight",
            "class_embedding.linear_1.bias": "label_emb.0.0.bias",
            "class_embedding.linear_2.weight": "label_emb.0.2.weight",
            "class_embedding.linear_2.bias": "label_emb.0.2.bias",
        },
        "addition_embed_type": {
            "add_embedding.linear_1.weight": "label_emb.0.0.weight",
            "add_embedding.linear_1.bias": "label_emb.0.0.bias",
            "add_embedding.linear_2.weight": "label_emb.0.2.weight",
            "add_embedding.linear_2.bias": "label_emb.0.2.bias",
        },
    },
    "vae": {
        "encoder.conv_in.weight": "encoder.conv_in.weight",
        "encoder.conv_in.bias": "encoder.conv_in.bias",
        "encoder.conv_out.weight": "encoder.conv_out.weight",
        "encoder.conv_out.bias": "encoder.conv_out.bias",
        "encoder.conv_norm_out.weight": "encoder.norm_out.weight",
        "encoder.conv_norm_out.bias": "encoder.norm_out.bias",
        "decoder.conv_in.weight": "decoder.conv_in.weight",
        "decoder.conv_in.bias": "decoder.conv_in.bias",
        "decoder.conv_out.weight": "decoder.conv_out.weight",
        "decoder.conv_out.bias": "decoder.conv_out.bias",
        "decoder.conv_norm_out.weight": "decoder.norm_out.weight",
        "decoder.conv_norm_out.bias": "decoder.norm_out.bias",
        "quant_conv.weight": "quant_conv.weight",
        "quant_conv.bias": "quant_conv.bias",
        "post_quant_conv.weight": "post_quant_conv.weight",
        "post_quant_conv.bias": "post_quant_conv.bias",
    },
    "openclip": {
        "layers": {
            "text_model.embeddings.position_embedding.weight": "positional_embedding",
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.final_layer_norm.weight": "ln_final.weight",
            "text_model.final_layer_norm.bias": "ln_final.bias",
            "text_projection.weight": "text_projection",
        },
        # NOTE(review): these entries look like substring replacements rather than full key names - confirm
        # against the OpenCLIP conversion routine.
        "transformer": {
            "text_model.encoder.layers.": "resblocks.",
            "layer_norm1": "ln_1",
            "layer_norm2": "ln_2",
            ".fc1.": ".c_fc.",
            ".fc2.": ".c_proj.",
            ".self_attn": ".attn",
            "transformer.text_model.final_layer_norm.": "ln_final.",
            "transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "transformer.text_model.embeddings.position_embedding.weight": "positional_embedding",
        },
    },
}

# Keys of the last OpenCLIP transformer block (index 23) and the text projection of SD 2.x text encoders.
# NOTE(review): presumably skipped when converting the SD 2.x text encoder - confirm at the usage site.
SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
    "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias",
    "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight",
    "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias",
    "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight",
    "cond_stage_model.model.transformer.resblocks.23.ln_1.bias",
    "cond_stage_model.model.transformer.resblocks.23.ln_1.weight",
    "cond_stage_model.model.transformer.resblocks.23.ln_2.bias",
    "cond_stage_model.model.transformer.resblocks.23.ln_2.weight",
    "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias",
    "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight",
    "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias",
    "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight",
    "cond_stage_model.model.text_projection",
]
scheduler_type argument SCHEDULER_DEFAULT_CONFIG = { "beta_schedule": "scaled_linear", "beta_start": 0.00085, "beta_end": 0.012, "interpolation_type": "linear", "num_train_timesteps": 1000, "prediction_type": "epsilon", "sample_max_value": 1.0, "set_alpha_to_one": False, "skip_prk_steps": True, "steps_offset": 1, "timestep_spacing": "leading", } LDM_VAE_KEY = "first_stage_model." LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215 PLAYGROUND_VAE_SCALING_FACTOR = 0.5 LDM_UNET_KEY = "model.diffusion_model." LDM_CONTROLNET_KEY = "control_model." LDM_CLIP_PREFIX_TO_REMOVE = [ "cond_stage_model.transformer.", "conditioner.embedders.0.transformer.", ] OPEN_CLIP_PREFIX = "conditioner.embedders.0.model." LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] class SingleFileComponentError(Exception): def __init__(self, message=None): self.message = message super().__init__(self.message) def is_valid_url(url): result = urlparse(url) if result.scheme and result.netloc: return True return False def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): if not is_valid_url(pretrained_model_name_or_path): raise ValueError("Invalid `pretrained_model_name_or_path` provided. 
def _is_model_weights_in_cached_folder(cached_folder, name):
    """Return True if ``cached_folder/name`` contains a known weights file.

    The file names checked are ``WEIGHTS_NAME`` and ``SAFETENSORS_WEIGHTS_NAME``.
    """
    component_dir = os.path.join(cached_folder, name)
    return any(
        os.path.isfile(os.path.join(component_dir, weights_name))
        for weights_name in (WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME)
    )


def load_single_file_checkpoint(
    pretrained_model_link_or_path,
    force_download=False,
    proxies=None,
    token=None,
    cache_dir=None,
    local_files_only=None,
    revision=None,
):
    """Load a single-file checkpoint from a local path or a Hub URL.

    A path that exists on disk is loaded directly. Anything else is treated as
    a URL: it is split into repo id and weights name and resolved to a local
    file through ``_get_model_file`` using the download-related arguments.

    Returns:
        The loaded state dict, unwrapped from any nested ``"state_dict"`` keys.
    """
    if not os.path.isfile(pretrained_model_link_or_path):
        repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
        pretrained_model_link_or_path = _get_model_file(
            repo_id,
            weights_name=weights_name,
            force_download=force_download,
            cache_dir=cache_dir,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
        )

    checkpoint = load_state_dict(pretrained_model_link_or_path)

    # some checkpoints contain the model state dict under a "state_dict" key
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    return checkpoint


def fetch_original_config(original_config_file, local_files_only=False):
    """Read and parse an original (YAML) model config from a file path or URL.

    Args:
        original_config_file: path of a local YAML file, or a URL to download it from.
        local_files_only: when True, URLs are rejected instead of downloaded.

    Returns:
        The object produced by ``yaml.safe_load`` for the config contents.

    Raises:
        ValueError: if the argument is neither an existing file nor a valid
            URL, or if it is a URL while ``local_files_only`` is set.
        requests.HTTPError: if the download answers with an error status.
    """
    if os.path.isfile(original_config_file):
        with open(original_config_file, "r") as fp:
            raw_config = fp.read()

    elif is_valid_url(original_config_file):
        if local_files_only:
            raise ValueError(
                "`local_files_only` is set to True, but a URL was provided as `original_config_file`. "
                "Please provide a valid local file path."
            )
        # Bound the request and surface HTTP errors; otherwise a stalled server
        # hangs the load and an error page is silently parsed as YAML.
        response = requests.get(original_config_file, timeout=60)
        response.raise_for_status()
        raw_config = BytesIO(response.content)

    else:
        raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")

    return yaml.safe_load(raw_config)
def is_clip_model(checkpoint):
    """Return True if the key registered as ``CHECKPOINT_KEY_NAMES["clip"]`` is in the checkpoint."""
    return CHECKPOINT_KEY_NAMES["clip"] in checkpoint


def is_clip_sdxl_model(checkpoint):
    """Return True if the key registered as ``CHECKPOINT_KEY_NAMES["clip_sdxl"]`` is in the checkpoint."""
    return CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint


def is_clip_sd3_model(checkpoint):
    """Return True if the key registered as ``CHECKPOINT_KEY_NAMES["clip_sd3"]`` is in the checkpoint."""
    return CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint


def is_open_clip_model(checkpoint):
    """Return True if the key registered as ``CHECKPOINT_KEY_NAMES["open_clip"]`` is in the checkpoint."""
    return CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint


def is_open_clip_sdxl_model(checkpoint):
    """Return True if the key registered as ``CHECKPOINT_KEY_NAMES["open_clip_sdxl"]`` is in the checkpoint."""
    return CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint


def is_open_clip_sd3_model(checkpoint):
    """Return True if the key registered as ``CHECKPOINT_KEY_NAMES["open_clip_sd3"]`` is in the checkpoint."""
    return CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint


def is_open_clip_sdxl_refiner_model(checkpoint):
    """Return True if the key registered as ``CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"]`` is in the checkpoint."""
    return CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint


def is_clip_model_in_single_file(class_obj, checkpoint):
    """Return True when ``class_obj`` is a CLIP text model class and the checkpoint holds CLIP weights.

    The class is matched by name (``CLIPTextModel`` or
    ``CLIPTextModelWithProjection``); the checkpoint is probed with the
    detector functions listed below.
    """
    detectors = (
        is_clip_model,
        is_clip_sd3_model,
        is_open_clip_model,
        is_open_clip_sdxl_model,
        is_open_clip_sdxl_refiner_model,
        is_open_clip_sd3_model,
    )
    # Every detector is evaluated before the class name is inspected.
    clip_found = any([detect(checkpoint) for detect in detectors])
    is_clip_class = class_obj.__name__ in ("CLIPTextModel", "CLIPTextModelWithProjection")
    return is_clip_class and clip_found


def infer_diffusers_model_type(checkpoint):
    """Guess the model type of a checkpoint from marker keys and tensor shapes.

    Marker key names come from ``CHECKPOINT_KEY_NAMES``. The checks run in a
    fixed priority order and the first match wins; ``"v1"`` is the fallback.

    Returns:
        A model type string such as ``"inpainting"``, ``"v2"``, ``"xl_base"``,
        ``"sd3"``, ``"flux-dev"`` or ``"v1"``.
    """
    names = CHECKPOINT_KEY_NAMES

    def present(name):
        return names[name] in checkpoint

    def dim(name, axis):
        return checkpoint[names[name]].shape[axis]

    if present("inpainting") and dim("inpainting", 1) == 9:
        if present("v2") and dim("v2", -1) == 1024:
            return "inpainting_v2"
        return "inpainting"

    if present("v2") and dim("v2", -1) == 1024:
        return "v2"

    # For these types the marker's registry name is also the model type.
    for model_type in ("playground-v2-5", "xl_base", "xl_refiner", "upscale", "controlnet"):
        if present(model_type):
            return model_type

    if present("stable_cascade_stage_c"):
        if dim("stable_cascade_stage_c", 0) == 1536:
            return "stable_cascade_stage_c_lite"
        if dim("stable_cascade_stage_c", 0) == 2048:
            return "stable_cascade_stage_c"

    if present("stable_cascade_stage_b"):
        if dim("stable_cascade_stage_b", -1) == 576:
            return "stable_cascade_stage_b_lite"
        if dim("stable_cascade_stage_b", -1) == 640:
            return "stable_cascade_stage_b"

    if present("sd3"):
        return "sd3"

    if present("animatediff"):
        if present("animatediff_scribble"):
            return "animatediff_scribble"
        if present("animatediff_rgb"):
            return "animatediff_rgb"
        if present("animatediff_v2"):
            return "animatediff_v2"
        # NOTE(review): the sdxl_beta marker is indexed without a presence
        # check, exactly as before — confirm it always exists at this point.
        if dim("animatediff_sdxl_beta", -1) == 320:
            return "animatediff_sdxl_beta"
        if dim("animatediff", 1) == 24:
            return "animatediff_v1"
        return "animatediff_v3"

    if present("flux"):
        if "guidance_in.in_layer.bias" in checkpoint:
            return "flux-dev"
        return "flux-schnell"

    return "v1"
def fetch_diffusers_config(checkpoint):
    """Return the default pipeline path registered for the checkpoint's inferred model type."""
    model_type = infer_diffusers_model_type(checkpoint)
    return DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]


def set_image_size(checkpoint, image_size=None):
    """Return ``image_size`` if truthy, else the default size for the inferred model type."""
    if image_size:
        return image_size

    model_type = infer_diffusers_model_type(checkpoint)
    return DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[model_type]


# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
def conv_attn_to_linear(checkpoint):
    """Drop trailing singleton conv dimensions from attention weights, in place.

    Keys whose last two dotted components are ``query.weight``, ``key.weight``
    or ``value.weight`` are indexed with ``[:, :, 0, 0]``; other keys containing
    ``proj_attn.weight`` are indexed with ``[:, :, 0]``. Only entries with more
    than two dimensions are touched.
    """
    attn_suffixes = ("query.weight", "key.weight", "value.weight")
    for key in list(checkpoint.keys()):
        if ".".join(key.split(".")[-2:]) in attn_suffixes:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def create_unet_diffusers_config_from_ldm(
    original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None
):
    """
    Creates a config for the diffusers based on the config of the LDM model.

    Args:
        original_config: parsed original config; read under ``["model"]["params"]``.
        checkpoint: state dict, used only to infer a default image size.
        image_size: deprecated override for the image size.
        upcast_attention: deprecated override stored as ``config["upcast_attention"]``.
        num_in_channels: deprecated override for ``config["in_channels"]``.

    Returns:
        A dict of UNet configuration values.
    """
    if image_size is not None:
        deprecation_message = (
            "Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file` "
            "is deprecated and will be ignored in future versions."
        )
        deprecate("image_size", "1.0.0", deprecation_message)

    image_size = set_image_size(checkpoint, image_size=image_size)

    model_params = original_config["model"]["params"]
    if "unet_config" in model_params and model_params["unet_config"] is not None:
        unet_params = model_params["unet_config"]["params"]
    else:
        unet_params = model_params["network_config"]["params"]

    if num_in_channels is not None:
        deprecation_message = (
            "Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file` "
            "is deprecated and will be ignored in future versions."
        )
        # Register the warning under the argument it is about (was "image_size").
        deprecate("num_in_channels", "1.0.0", deprecation_message)
        in_channels = num_in_channels
    else:
        in_channels = unet_params["in_channels"]

    vae_params = model_params["first_stage_config"]["params"]["ddconfig"]
    block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
    attention_resolutions = unet_params["attention_resolutions"]
    num_blocks = len(block_out_channels)

    # Walk down the resolution pyramid: the factor doubles after every block
    # except the last one.
    down_block_types = []
    resolution = 1
    for i in range(num_blocks):
        down_block_types.append("CrossAttnDownBlock2D" if resolution in attention_resolutions else "DownBlock2D")
        if i != num_blocks - 1:
            resolution *= 2

    # Walk back up, starting from the factor reached above.
    up_block_types = []
    for _ in range(num_blocks):
        up_block_types.append("CrossAttnUpBlock2D" if resolution in attention_resolutions else "UpBlock2D")
        resolution //= 2

    transformer_depth = unet_params["transformer_depth"]
    if transformer_depth is None:
        transformer_layers_per_block = 1
    elif isinstance(transformer_depth, int):
        transformer_layers_per_block = transformer_depth
    else:
        transformer_layers_per_block = list(transformer_depth)

    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)

    head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
    use_linear_projection = (
        unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection and head_dim is None:
        # stable diffusion 2-base-512 and 2-768
        head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
        head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]

    class_embed_type = None
    addition_embed_type = None
    addition_time_embed_dim = None
    projection_class_embeddings_input_dim = None
    context_dim = None

    if unet_params["context_dim"] is not None:
        context_dim = (
            unet_params["context_dim"]
            if isinstance(unet_params["context_dim"], int)
            else unet_params["context_dim"][0]
        )

    if "num_classes" in unet_params and unet_params["num_classes"] == "sequential":
        if context_dim in [2048, 1280]:
            # SDXL
            addition_embed_type = "text_time"
            addition_time_embed_dim = 256
        else:
            class_embed_type = "projection"
        assert "adm_in_channels" in unet_params, "`adm_in_channels` is required when `num_classes` is 'sequential'"
        projection_class_embeddings_input_dim = unet_params["adm_in_channels"]

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": in_channels,
        "down_block_types": down_block_types,
        "block_out_channels": block_out_channels,
        "layers_per_block": unet_params["num_res_blocks"],
        "cross_attention_dim": context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "addition_embed_type": addition_embed_type,
        "addition_time_embed_dim": addition_time_embed_dim,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "transformer_layers_per_block": transformer_layers_per_block,
    }

    if upcast_attention is not None:
        deprecation_message = (
            "Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file` "
            "is deprecated and will be ignored in future versions."
        )
        # Register the warning under the argument it is about (was "image_size").
        deprecate("upcast_attention", "1.0.0", deprecation_message)
        config["upcast_attention"] = upcast_attention

    if "disable_self_attentions" in unet_params:
        config["only_cross_attention"] = unet_params["disable_self_attentions"]

    if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
        config["num_class_embeds"] = unet_params["num_classes"]

    config["out_channels"] = unet_params["out_channels"]
    config["up_block_types"] = up_block_types

    return config
def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs):
    """Create a ControlNet config from an original LDM-style config.

    Most values are copied from the UNet config derived from the same
    ``original_config``; ``conditioning_channels`` comes from
    ``control_stage_config.params.hint_channels``.

    Args:
        original_config: parsed original config; read under ``["model"]["params"]``.
        checkpoint: state dict, forwarded to the UNet config builder.
        image_size: deprecated override for the image size.
        **kwargs: accepted and ignored.

    Returns:
        A dict of ControlNet configuration values.
    """
    if image_size is not None:
        deprecation_message = (
            "Configuring ControlNetModel with the `image_size` argument "
            "is deprecated and will be ignored in future versions."
        )
        deprecate("image_size", "1.0.0", deprecation_message)

    unet_params = original_config["model"]["params"]["control_stage_config"]["params"]

    # `checkpoint` is a required positional argument of the UNet builder; the
    # previous call omitted it and raised TypeError. The caller's `image_size`
    # is forwarded untouched: the builder resolves the default itself, and none
    # of the values copied below depend on the resolved size.
    diffusers_unet_config = create_unet_diffusers_config_from_ldm(
        original_config, checkpoint, image_size=image_size
    )

    shared_keys = (
        "in_channels",
        "down_block_types",
        "block_out_channels",
        "layers_per_block",
        "cross_attention_dim",
        "attention_head_dim",
        "use_linear_projection",
        "class_embed_type",
        "addition_embed_type",
        "addition_time_embed_dim",
        "projection_class_embeddings_input_dim",
        "transformer_layers_per_block",
    )
    controlnet_config = {"conditioning_channels": unet_params["hint_channels"]}
    controlnet_config.update({key: diffusers_unet_config[key] for key in shared_keys})

    return controlnet_config
def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None):
    """
    Creates a config for the diffusers based on the config of the LDM model.

    Args:
        original_config: parsed original config; read under ``["model"]["params"]``.
        checkpoint: state dict; ``edm_mean``/``edm_std`` entries, when both are
            present, are copied into the config as ``latents_mean``/``latents_std``.
        image_size: deprecated override for the image size.
        scaling_factor: explicit scaling factor. When None it falls back to
            ``PLAYGROUND_VAE_SCALING_FACTOR`` (if latent statistics are present),
            then the config's ``scale_factor``, then ``LDM_VAE_DEFAULT_SCALING_FACTOR``.

    Returns:
        A dict of VAE configuration values.
    """
    if image_size is not None:
        # Trailing space added so the two fragments no longer run together.
        deprecation_message = (
            "Configuring AutoencoderKL with the `image_size` argument "
            "is deprecated and will be ignored in future versions."
        )
        deprecate("image_size", "1.0.0", deprecation_message)

    image_size = set_image_size(checkpoint, image_size=image_size)

    has_latent_stats = "edm_mean" in checkpoint and "edm_std" in checkpoint
    latents_mean = checkpoint["edm_mean"] if has_latent_stats else None
    latents_std = checkpoint["edm_std"] if has_latent_stats else None

    model_params = original_config["model"]["params"]
    vae_params = model_params["first_stage_config"]["params"]["ddconfig"]

    if scaling_factor is None:
        if latents_mean is not None and latents_std is not None:
            scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
        elif "scale_factor" in model_params:
            scaling_factor = model_params["scale_factor"]
        else:
            scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR

    block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = {
        "sample_size": image_size,
        "in_channels": vae_params["in_channels"],
        "out_channels": vae_params["out_ch"],
        "down_block_types": down_block_types,
        "up_block_types": up_block_types,
        "block_out_channels": block_out_channels,
        "latent_channels": vae_params["z_channels"],
        "layers_per_block": vae_params["num_res_blocks"],
        "scaling_factor": scaling_factor,
    }
    if latents_mean is not None and latents_std is not None:
        config.update({"latents_mean": latents_mean, "latents_std": latents_std})

    return config
def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping=None):
    """Copy UNet resnet entries into ``new_checkpoint`` under diffusers-style names.

    Each key has its LDM sub-layer names substituted in a fixed order and then,
    if ``mapping`` is truthy, ``mapping["old"]`` replaced by ``mapping["new"]``.
    Values are read with ``checkpoint.get``, so a missing key is stored as None.
    """
    sublayer_renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )
    for source_key in ldm_keys:
        target_key = source_key
        for ldm_name, diffusers_name in sublayer_renames:
            target_key = target_key.replace(ldm_name, diffusers_name)
        if mapping:
            target_key = target_key.replace(mapping["old"], mapping["new"])
        new_checkpoint[target_key] = checkpoint.get(source_key)
def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping):
    """Copy UNet attention entries, replacing ``mapping["old"]`` with ``mapping["new"]`` in each key."""
    old, new = mapping["old"], mapping["new"]
    for source_key in ldm_keys:
        new_checkpoint[source_key.replace(old, new)] = checkpoint.get(source_key)


def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
    """Copy VAE resnet entries, applying ``mapping`` and renaming ``nin_shortcut`` to ``conv_shortcut``."""
    for source_key in keys:
        target_key = source_key.replace(mapping["old"], mapping["new"])
        target_key = target_key.replace("nin_shortcut", "conv_shortcut")
        new_checkpoint[target_key] = checkpoint.get(source_key)


def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
    """Copy VAE attention entries under diffusers names and flatten conv-shaped weights.

    After ``mapping`` is applied, norm/q/k/v/proj_out parameter names are
    renamed in a fixed order. Values with 3 or 4 dimensions are indexed down
    to their first two dimensions.
    """
    parameter_renames = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "to_q.weight"),
        ("q.bias", "to_q.bias"),
        ("k.weight", "to_k.weight"),
        ("k.bias", "to_k.bias"),
        ("v.weight", "to_v.weight"),
        ("v.bias", "to_v.bias"),
        ("proj_out.weight", "to_out.0.weight"),
        ("proj_out.bias", "to_out.0.bias"),
    )
    for source_key in keys:
        target_key = source_key.replace(mapping["old"], mapping["new"])
        for ldm_name, diffusers_name in parameter_renames:
            target_key = target_key.replace(ldm_name, diffusers_name)

        value = checkpoint.get(source_key)
        new_checkpoint[target_key] = value

        # proj_attn.weight has to be converted from conv 1D to linear
        rank = len(value.shape)
        if rank == 3:
            new_checkpoint[target_key] = value[:, :, 0]
        elif rank == 4:
            new_checkpoint[target_key] = value[:, :, 0, 0]


def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs):
    """Rename Stable Cascade attention parameters to diffusers names.

    Fused ``attn.in_proj_{weight,bias}`` entries are split into three chunks
    along dim 0 and stored as ``to_q``/``to_k``/``to_v``; ``attn.out_proj``
    becomes ``to_out.0``. When the checkpoint has no ``clip_txt_mapper.weight``
    key, ``clip_mapper`` entries are also renamed to ``clip_txt_pooled_mapper``.
    All other entries are copied unchanged.

    Returns:
        A new dict; the input checkpoint is not modified.
    """
    # rename clip_mapper to clip_txt_pooled_mapper (only when not a stage C checkpoint)
    rename_clip_mapper = "clip_txt_mapper.weight" not in checkpoint
    projections = ("to_q", "to_k", "to_v")

    state_dict = {}
    for key, value in checkpoint.items():
        if key.endswith(("in_proj_weight", "in_proj_bias")):
            kind = "weight" if key.endswith("in_proj_weight") else "bias"
            chunks = value.chunk(3, 0)
            for index, projection in enumerate(projections):
                state_dict[key.replace(f"attn.in_proj_{kind}", f"{projection}.{kind}")] = chunks[index]
        elif key.endswith(("out_proj.weight", "out_proj.bias")):
            kind = key.rsplit(".", 1)[-1]
            state_dict[key.replace(f"attn.out_proj.{kind}", f"to_out.0.{kind}")] = value
        elif rename_clip_mapper and key.endswith(("clip_mapper.weight", "clip_mapper.bias")):
            kind = key.rsplit(".", 1)[-1]
            state_dict[key.replace(f"clip_mapper.{kind}", f"clip_txt_pooled_mapper.{kind}")] = value
        else:
            state_dict[key] = value

    return state_dict
def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Args:
        checkpoint: full original state dict; UNet entries are those whose keys
            start with ``LDM_UNET_KEY``.
        config: UNet config dict. ``layers_per_block`` is always read;
            ``class_embed_type``, ``addition_embed_type`` and ``num_class_embeds``
            are consulted when present.
        extract_ema: when True and the checkpoint has more than 100 keys
            starting with ``model_ema``, EMA weights are extracted instead of
            the regular ones.
        **kwargs: accepted and ignored.

    Returns:
        A new dict keyed by diffusers-style UNet parameter names.
    """
    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())
    unet_key = LDM_UNET_KEY

    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        logger.warning("Checkpoint has both EMA and non-EMA weights.")
        logger.warning(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                # The EMA key is built by dropping the first dotted component of
                # the regular key and concatenating the rest without dots.
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            logger.warning(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )
        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key)

    new_checkpoint = {}
    # Directly mapped layers: entries missing from the checkpoint are skipped.
    ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
    for diffusers_key, ldm_key in ldm_unet_keys.items():
        if ldm_key not in unet_state_dict:
            continue
        new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]

    # Class / addition embeddings are indexed directly, so a missing source key
    # raises KeyError here.
    if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]):
        class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"]
        for diffusers_key, ldm_key in class_embed_keys.items():
            new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]

    if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"):
        addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"]
        for diffusers_key, ldm_key in addition_embed_keys.items():
            new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]

    # Relevant to StableDiffusionUpscalePipeline
    if "num_class_embeds" in config:
        if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
            new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]

    # Retrieves the keys for the input blocks only
    # The block count is the number of distinct two-component key prefixes.
    # Grouping uses a substring test, so "input_blocks.1" also collects
    # "input_blocks.10"...; the per-block filters further down include the
    # following ".0"/".1" component and therefore select only block i.
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # Down blocks
    # Input block 0 is skipped here (its conv is covered by the direct layer
    # mapping — presumably "conv_in"; confirm against DIFFUSERS_TO_LDM_MAPPING).
    # The remaining blocks are grouped in runs of `layers_per_block + 1`.
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        # Sub-module ".0" holds the resnet unless it is a downsampling op (".0.op").
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        update_unet_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            unet_state_dict,
            {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
        )

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get(
                f"input_blocks.{i}.0.op.bias"
            )

        # Sub-module ".1", when present, holds the attention layers.
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
        if attentions:
            update_unet_attention_ldm_to_diffusers(
                attentions,
                new_checkpoint,
                unet_state_dict,
                {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
            )

    # Mid blocks
    # Even indices are treated as resnets, odd ones as attentions:
    # 0 -> resnets.0, 1 -> attentions.0, 2 -> resnets.1.
    for key in middle_blocks.keys():
        diffusers_key = max(key - 1, 0)
        if key % 2 == 0:
            update_unet_resnet_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                unet_state_dict,
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
            )
        else:
            update_unet_attention_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                unet_state_dict,
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
            )

    # Up Blocks
    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)

        resnets = [
            key for key in output_blocks[i] if f"output_blocks.{i}.0" in key and f"output_blocks.{i}.0.op" not in key
        ]
        update_unet_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            unet_state_dict,
            {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"},
        )

        # Sub-module ".1" is an attention layer unless it is an upsampling conv (".1.conv").
        attentions = [
            key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and f"output_blocks.{i}.1.conv" not in key
        ]
        if attentions:
            update_unet_attention_ldm_to_diffusers(
                attentions,
                new_checkpoint,
                unet_state_dict,
                {"old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}"},
            )

        # The upsampler conv is looked up at sub-module ".1" and then at ".2";
        # if both exist, the ".2" values overwrite the ".1" ones.
        if f"output_blocks.{i}.1.conv.weight" in unet_state_dict:
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                f"output_blocks.{i}.1.conv.weight"
            ]
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                f"output_blocks.{i}.1.conv.bias"
            ]
        if f"output_blocks.{i}.2.conv.weight" in unet_state_dict:
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                f"output_blocks.{i}.2.conv.weight"
            ]
            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                f"output_blocks.{i}.2.conv.bias"
            ]

    return new_checkpoint
def convert_controlnet_checkpoint(
    checkpoint,
    config,
    **kwargs,
):
    """
    Converts an original ControlNet state dict to diffusers-style parameter names.

    Args:
        checkpoint: original state dict, either a standalone ControlNet
            checkpoint or a full checkpoint whose ControlNet keys start with
            ``LDM_CONTROLNET_KEY``.
        config: ControlNet config dict; only ``layers_per_block`` is read here.
        **kwargs: accepted and ignored.

    Returns:
        A new dict keyed by diffusers-style parameter names. Entries read with
        ``.get`` are stored as None when the source key is absent.
    """
    # Some controlnet ckpt files are distributed independently from the rest of the
    # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
    if "time_embed.0.weight" in checkpoint:
        # Keys are already unprefixed: the checkpoint itself is used as the source.
        controlnet_state_dict = checkpoint
    else:
        controlnet_state_dict = {}
        keys = list(checkpoint.keys())
        controlnet_key = LDM_CONTROLNET_KEY
        for key in keys:
            if key.startswith(controlnet_key):
                controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key)

    new_checkpoint = {}
    # Directly mapped layers: entries missing from the checkpoint are skipped.
    ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"]
    for diffusers_key, ldm_key in ldm_controlnet_keys.items():
        if ldm_key not in controlnet_state_dict:
            continue
        new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key]

    # Retrieves the keys for the input blocks only
    # Grouping uses a substring test; the per-block filters below include the
    # following ".0"/".1" component and therefore select only block i.
    num_input_blocks = len(
        {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer}
    )
    input_blocks = {
        layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Down blocks
    # Input block 0 is skipped; the rest are grouped in runs of `layers_per_block + 1`.
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        # Sub-module ".0" holds the resnet unless it is a downsampling op (".0.op").
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        update_unet_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            controlnet_state_dict,
            {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
        )

        if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get(
                f"input_blocks.{i}.0.op.bias"
            )

        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
        if attentions:
            update_unet_attention_ldm_to_diffusers(
                attentions,
                new_checkpoint,
                controlnet_state_dict,
                {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
            )

    # controlnet down blocks
    # One "zero_convs.{i}.0" entry is read per input block, including block 0.
    for i in range(num_input_blocks):
        new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight")
        new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias")

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len(
        {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer}
    )
    middle_blocks = {
        layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Mid blocks
    # Even indices are treated as resnets, odd ones as attentions:
    # 0 -> resnets.0, 1 -> attentions.0, 2 -> resnets.1.
    for key in middle_blocks.keys():
        diffusers_key = max(key - 1, 0)
        if key % 2 == 0:
            update_unet_resnet_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                controlnet_state_dict,
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
            )
        else:
            update_unet_attention_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                controlnet_state_dict,
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
            )

    # mid block
    new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight")
    new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias")

    # controlnet cond embedding blocks
    # Counts distinct "input_hint_block.N" prefixes, excluding keys that contain
    # "input_hint_block.0" or "input_hint_block.14" (those two indices are the
    # sources of conv_in / conv_out in the direct "controlnet" layer mapping).
    cond_embedding_blocks = {
        ".".join(layer.split(".")[:2])
        for layer in controlnet_state_dict
        if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer)
    }
    num_cond_embedding_blocks = len(cond_embedding_blocks)

    for idx in range(1, num_cond_embedding_blocks + 1):
        diffusers_idx = idx - 1
        # Source layers are read from even indices 2, 4, 6, ...
        # NOTE(review): presumably the odd indices are parameter-free layers — confirm.
        cond_block_id = 2 * idx

        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get(
            f"input_hint_block.{cond_block_id}.weight"
        )
        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get(
            f"input_hint_block.{cond_block_id}.bias"
        )

    return new_checkpoint
convert_ldm_vae_checkpoint(checkpoint, config): # extract state dict for VAE # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys vae_state_dict = {} keys = list(checkpoint.keys()) vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else "" for key in keys: if key.startswith(vae_key): vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) new_checkpoint = {} vae_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["vae"] for diffusers_key, ldm_key in vae_diffusers_ldm_map.items(): if ldm_key not in vae_state_dict: continue new_checkpoint[diffusers_key] = vae_state_dict[ldm_key] # Retrieves the keys for the encoder down blocks only num_down_blocks = len(config["down_block_types"]) down_blocks = { layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) } for i in range(num_down_blocks): resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] update_vae_resnet_ldm_to_diffusers( resnets, new_checkpoint, vae_state_dict, mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}, ) if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get( f"encoder.down.{i}.downsample.conv.weight" ) new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get( f"encoder.down.{i}.downsample.conv.bias" ) mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] num_mid_res_blocks = 2 for i in range(1, num_mid_res_blocks + 1): resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] update_vae_resnet_ldm_to_diffusers( resnets, new_checkpoint, vae_state_dict, mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}, ) mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] update_vae_attentions_ldm_to_diffusers( 
def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
    """Extract CLIP text-encoder weights from an LDM-style checkpoint.

    Every key that starts with one of the known LDM CLIP prefixes
    (``LDM_CLIP_PREFIX_TO_REMOVE``), or with the optional ``remove_prefix``,
    is copied into the result with that leading prefix stripped. Keys that
    match none of the prefixes are dropped.

    Args:
        checkpoint: Mapping of parameter names to values.
        remove_prefix: Optional extra prefix to strip in addition to the
            module-level defaults. Ignored when falsy.

    Returns:
        A new dict mapping the stripped key names to the same value objects.
    """
    prefixes = list(LDM_CLIP_PREFIX_TO_REMOVE)
    if remove_prefix:
        prefixes.append(remove_prefix)

    text_model_dict = {}
    for key, value in checkpoint.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                # Slice instead of `str.replace`: only the leading prefix must go,
                # a later repetition of the same text inside the key has to survive.
                text_model_dict[key[len(prefix) :]] = value

    return text_model_dict
weight_value[:text_proj_dim].clone().detach() text_model_dict[diffusers_key + ".k_proj.bias"] = ( weight_value[text_proj_dim : text_proj_dim * 2].clone().detach() ) text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach() else: text_model_dict[diffusers_key] = checkpoint.get(key) return text_model_dict def create_diffusers_clip_model_from_ldm( cls, checkpoint, subfolder="", config=None, torch_dtype=None, local_files_only=None, is_legacy_loading=False, ): if config: config = {"pretrained_model_name_or_path": config} else: config = fetch_diffusers_config(checkpoint) # For backwards compatibility # Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo # in the cache_dir, rather than in a subfolder of the Diffusers model if is_legacy_loading: logger.warning( ( "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update " "the local cache directory with the necessary CLIP model config files. " "Attempting to load CLIP model from legacy cache directory." 
) ) if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint): clip_config = "openai/clip-vit-large-patch14" config["pretrained_model_name_or_path"] = clip_config subfolder = "" elif is_open_clip_model(checkpoint): clip_config = "stabilityai/stable-diffusion-2" config["pretrained_model_name_or_path"] = clip_config subfolder = "text_encoder" else: clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" config["pretrained_model_name_or_path"] = clip_config subfolder = "" model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only) ctx = init_empty_weights if is_accelerate_available() else nullcontext with ctx(): model = cls(model_config) position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1] if is_clip_model(checkpoint): diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) elif ( is_clip_sdxl_model(checkpoint) and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim ): diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) elif ( is_clip_sd3_model(checkpoint) and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim ): diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.") diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim) elif is_open_clip_model(checkpoint): prefix = "cond_stage_model.model." diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) elif ( is_open_clip_sdxl_model(checkpoint) and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim ): prefix = "conditioner.embedders.1.model." diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix) elif is_open_clip_sdxl_refiner_model(checkpoint): prefix = "conditioner.embedders.0.model." 
def _legacy_load_scheduler(
    cls,
    checkpoint,
    component_name,
    original_config=None,
    **kwargs,
):
    """Create a scheduler for a single-file checkpoint via the deprecated loading path.

    Args:
        cls: Scheduler class; its `from_config` is used when no `scheduler_type` is
            selected and for the `low_res_scheduler` component.
        checkpoint: Checkpoint mapping. It is passed to `infer_diffusers_model_type`
            and its optional `"global_step"` entry is read.
        component_name: Name of the pipeline component being loaded. The value
            `"low_res_scheduler"` returns a scheduler built from a fixed config.
        original_config: Optional original model config. When given,
            `original_config["model"]["params"]` supplies `timesteps`,
            `linear_start` and `linear_end`.
        **kwargs: May carry the deprecated `scheduler_type` and `prediction_type`.

    Returns:
        A scheduler instance.

    Raises:
        ValueError: If `scheduler_type` is not one of the supported names.
    """
    scheduler_type = kwargs.get("scheduler_type", None)
    prediction_type = kwargs.get("prediction_type", None)

    if scheduler_type is not None:
        deprecation_message = (
            "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`."
        )
        deprecate("scheduler_type", "1.0.0", deprecation_message)

    if prediction_type is not None:
        deprecation_message = (
            "Please configure an instance of a Scheduler with the appropriate `prediction_type` "
            "and pass the object directly to the `scheduler` argument in `from_single_file`."
        )
        deprecate("prediction_type", "1.0.0", deprecation_message)

    # Work on a copy: the overrides written below must not leak into the
    # module-level defaults and affect schedulers created by later calls.
    scheduler_config = dict(SCHEDULER_DEFAULT_CONFIG)
    model_type = infer_diffusers_model_type(checkpoint=checkpoint)

    global_step = checkpoint.get("global_step")

    if original_config:
        # `params` is read with `.get` further down as well; `getattr` on a plain
        # mapping never finds the key and would always fall back to 1000.
        num_train_timesteps = original_config["model"]["params"].get("timesteps", 1000)
    else:
        num_train_timesteps = 1000

    scheduler_config["num_train_timesteps"] = num_train_timesteps

    if model_type == "v2":
        if prediction_type is None:
            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
            # as it relies on a brittle global step parameter here
            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"

    else:
        prediction_type = prediction_type or "epsilon"

    scheduler_config["prediction_type"] = prediction_type

    if model_type in ["xl_base", "xl_refiner"]:
        scheduler_type = "euler"
    elif model_type == "playground":
        scheduler_type = "edm_dpm_solver_multistep"
    else:
        if original_config:
            beta_start = original_config["model"]["params"].get("linear_start")
            beta_end = original_config["model"]["params"].get("linear_end")

        else:
            # NOTE(review): fallback values kept exactly as they were; confirm they are intended.
            beta_start = 0.02
            beta_end = 0.085

        scheduler_config["beta_start"] = beta_start
        scheduler_config["beta_end"] = beta_end
        scheduler_config["beta_schedule"] = "scaled_linear"
        scheduler_config["clip_sample"] = False
        scheduler_config["set_alpha_to_one"] = False

    # to deal with an edge case StableDiffusionUpscale pipeline has two schedulers
    if component_name == "low_res_scheduler":
        return cls.from_config(
            {
                "beta_end": 0.02,
                "beta_schedule": "scaled_linear",
                "beta_start": 0.0001,
                "clip_sample": True,
                "num_train_timesteps": 1000,
                "prediction_type": "epsilon",
                "trained_betas": None,
                "variance_type": "fixed_small",
            }
        )

    if scheduler_type is None:
        return cls.from_config(scheduler_config)

    elif scheduler_type == "pndm":
        scheduler_config["skip_prk_steps"] = True
        scheduler = PNDMScheduler.from_config(scheduler_config)

    elif scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler.from_config(scheduler_config)

    elif scheduler_type == "heun":
        scheduler = HeunDiscreteScheduler.from_config(scheduler_config)

    elif scheduler_type == "euler":
        scheduler = EulerDiscreteScheduler.from_config(scheduler_config)

    elif scheduler_type == "euler-ancestral":
        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)

    elif scheduler_type == "dpm":
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)

    elif scheduler_type == "ddim":
        scheduler = DDIMScheduler.from_config(scheduler_config)

    elif scheduler_type == "edm_dpm_solver_multistep":
        # This scheduler is constructed from its own fixed settings rather than
        # from the config assembled above.
        scheduler_config = {
            "algorithm_type": "dpmsolver++",
            "dynamic_thresholding_ratio": 0.995,
            "euler_at_final": False,
            "final_sigmas_type": "zero",
            "lower_order_final": True,
            "num_train_timesteps": 1000,
            "prediction_type": "epsilon",
            "rho": 7.0,
            "sample_max_value": 1.0,
            "sigma_data": 0.5,
            "sigma_max": 80.0,
            "sigma_min": 0.002,
            "solver_order": 2,
            "solver_type": "midpoint",
            "thresholding": False,
        }
        scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config)

    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

    return scheduler
# The original SD3 AdaLayerNormContinuous splits its linear projection output into (shift, scale),
# whereas diffusers splits it into (scale, shift). Swapping the two halves of the projection
# weights lets the diffusers implementation be used unchanged.
def swap_scale_shift(weight, dim):
    """Return `weight` with its two halves along dimension 0 exchanged.

    Args:
        weight: Tensor that is split into two chunks along dimension 0.
        dim: Accepted for call compatibility; not used by the computation.

    Returns:
        A new tensor holding the second chunk followed by the first chunk.
    """
    leading_half, trailing_half = torch.chunk(weight, 2, dim=0)
    return torch.cat((trailing_half, leading_half), dim=0)
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( "t_embedder.mlp.0.weight" ) converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( "t_embedder.mlp.2.weight" ) converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") # Context projections. converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight") converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias") # Pooled context projection. converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight") converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias") converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight") converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias") # Transformer blocks 🎸. 
for i in range(num_layers): # Q, K, V sample_q, sample_k, sample_v = torch.chunk( checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0 ) context_q, context_k, context_v = torch.chunk( checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0 ) sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0 ) context_q_bias, context_k_bias, context_v_bias = torch.chunk( checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0 ) converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q]) converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias]) converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k]) converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias]) converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v]) converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias]) converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q]) converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias]) converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k]) converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias]) converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v]) converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias]) # output projections. 
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop( f"joint_blocks.{i}.x_block.attn.proj.weight" ) converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop( f"joint_blocks.{i}.x_block.attn.proj.bias" ) if not (i == num_layers - 1): converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop( f"joint_blocks.{i}.context_block.attn.proj.weight" ) converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop( f"joint_blocks.{i}.context_block.attn.proj.bias" ) # norms. converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight" ) converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop( f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias" ) if not (i == num_layers - 1): converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop( f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight" ) converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop( f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias" ) else: converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift( checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"), dim=caption_projection_dim, ) converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift( checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"), dim=caption_projection_dim, ) # ffs. 
converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop( f"joint_blocks.{i}.x_block.mlp.fc1.weight" ) converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop( f"joint_blocks.{i}.x_block.mlp.fc1.bias" ) converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop( f"joint_blocks.{i}.x_block.mlp.fc2.weight" ) converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop( f"joint_blocks.{i}.x_block.mlp.fc2.bias" ) if not (i == num_layers - 1): converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop( f"joint_blocks.{i}.context_block.mlp.fc1.weight" ) converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop( f"joint_blocks.{i}.context_block.mlp.fc1.bias" ) converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop( f"joint_blocks.{i}.context_block.mlp.fc2.weight" ) converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop( f"joint_blocks.{i}.context_block.mlp.fc2.bias" ) # Final blocks. 
def is_t5_in_single_file(checkpoint):
    """Return True when `checkpoint` contains the T5-XXL shared embedding weight key."""
    return "text_encoders.t5xxl.transformer.shared.weight" in checkpoint


def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
    """Extract the T5 text-encoder entries from a single-file SD3 checkpoint.

    Keys starting with `text_encoders.t5xxl.transformer.` are copied into the
    result with that leading prefix removed; all other keys are ignored.

    Args:
        checkpoint: Mapping of parameter names to values.

    Returns:
        A new dict mapping the stripped key names to the same value objects.
    """
    prefix = "text_encoders.t5xxl.transformer."
    # Slice instead of `str.replace`: only the leading prefix must go, a later
    # repetition of the same text inside the key has to survive.
    return {key[len(prefix) :]: value for key, value in checkpoint.items() if key.startswith(prefix)}
def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Rename AnimateDiff motion-module keys to the diffusers naming scheme.

    Entries whose key contains `pos_encoder` are dropped. For every other
    entry a fixed sequence of substring substitutions is applied to the key;
    the value is carried over untouched.

    Args:
        checkpoint: Mapping of parameter names to values.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        A new dict with the renamed keys.
    """
    # Applied in this exact order to each key.
    substitutions = (
        (".norms.0", ".norm1"),
        (".norms.1", ".norm2"),
        (".ff_norm", ".norm3"),
        (".attention_blocks.0", ".attn1"),
        (".attention_blocks.1", ".attn2"),
        (".temporal_transformer", ""),
    )

    renamed = {}
    for source_key, value in checkpoint.items():
        if "pos_encoder" in source_key:
            continue

        target_key = source_key
        for needle, replacement in substitutions:
            target_key = target_key.replace(needle, replacement)
        renamed[target_key] = value

    return renamed
Here we swap the linear projection weights in order to be able to use diffusers implementation def swap_scale_shift(weight): shift, scale = weight.chunk(2, dim=0) new_weight = torch.cat([scale, shift], dim=0) return new_weight ## time_text_embed.timestep_embedder <- time_in converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop( "time_in.in_layer.weight" ) converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias") converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop( "time_in.out_layer.weight" ) converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias") ## time_text_embed.text_embedder <- vector_in converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight") converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias") converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop( "vector_in.out_layer.weight" ) converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias") # guidance has_guidance = any("guidance" in k for k in checkpoint) if has_guidance: converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop( "guidance_in.in_layer.weight" ) converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop( "guidance_in.in_layer.bias" ) converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop( "guidance_in.out_layer.weight" ) converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop( "guidance_in.out_layer.bias" ) # context_embedder converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight") converted_state_dict["context_embedder.bias"] = 
checkpoint.pop("txt_in.bias") # x_embedder converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight") converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias") # double transformer blocks for i in range(num_layers): block_prefix = f"transformer_blocks.{i}." # norms. ## norm1 converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop( f"double_blocks.{i}.img_mod.lin.weight" ) converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop( f"double_blocks.{i}.img_mod.lin.bias" ) ## norm1_context converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop( f"double_blocks.{i}.txt_mod.lin.weight" ) converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop( f"double_blocks.{i}.txt_mod.lin.bias" ) # Q, K, V sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0) context_q, context_k, context_v = torch.chunk( checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0 ) sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0 ) context_q_bias, context_k_bias, context_v_bias = torch.chunk( checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0 ) converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q]) converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias]) converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k]) converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias]) converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v]) converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias]) converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q]) converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias]) 
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k]) converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias]) converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v]) converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias]) # qk_norm converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( f"double_blocks.{i}.img_attn.norm.query_norm.scale" ) converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( f"double_blocks.{i}.img_attn.norm.key_norm.scale" ) converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop( f"double_blocks.{i}.txt_attn.norm.query_norm.scale" ) converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop( f"double_blocks.{i}.txt_attn.norm.key_norm.scale" ) # ff img_mlp converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop( f"double_blocks.{i}.img_mlp.0.weight" ) converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias") converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight") converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias") converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop( f"double_blocks.{i}.txt_mlp.0.weight" ) converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop( f"double_blocks.{i}.txt_mlp.0.bias" ) converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop( f"double_blocks.{i}.txt_mlp.2.weight" ) converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop( f"double_blocks.{i}.txt_mlp.2.bias" ) # output projections. 
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop( f"double_blocks.{i}.img_attn.proj.weight" ) converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop( f"double_blocks.{i}.img_attn.proj.bias" ) converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop( f"double_blocks.{i}.txt_attn.proj.weight" ) converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop( f"double_blocks.{i}.txt_attn.proj.bias" ) # single transfomer blocks for i in range(num_single_layers): block_prefix = f"single_transformer_blocks.{i}." # norm.linear <- single_blocks.0.modulation.lin converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop( f"single_blocks.{i}.modulation.lin.weight" ) converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop( f"single_blocks.{i}.modulation.lin.bias" ) # Q, K, V, mlp mlp_hidden_dim = int(inner_dim * mlp_ratio) split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0) q_bias, k_bias, v_bias, mlp_bias = torch.split( checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0 ) converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q]) converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias]) converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k]) converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias]) converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v]) converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias]) converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp]) converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias]) # qk norm converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( f"single_blocks.{i}.norm.query_norm.scale" ) 
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( f"single_blocks.{i}.norm.key_norm.scale" ) # output projections. converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight") converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias") converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( checkpoint.pop("final_layer.adaLN_modulation.1.weight") ) converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( checkpoint.pop("final_layer.adaLN_modulation.1.bias") ) return converted_state_dict ================================================ FILE: src/diffusers/loaders/textual_inversion.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Dict, List, Optional, Union

import safetensors
import torch
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn

from ..models.modeling_utils import load_state_dict
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging


if is_transformers_available():
    from transformers import PreTrainedModel, PreTrainedTokenizer

if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

logger = logging.get_logger(__name__)

TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"


@validate_hf_hub_args
def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
    r"""
    Load one textual inversion state dict per entry of `pretrained_model_name_or_paths`.

    Entries that already are a `dict` or a `torch.Tensor` are passed through unchanged. Every other entry is
    resolved with `_get_model_file` and read from disk: a `.safetensors` file is tried first (unless
    `use_safetensors=False` or `weight_name` names a non-safetensors file), then the pickle-based file.

    Parameters:
        pretrained_model_name_or_paths (`list`):
            Model ids, paths, state dicts or tensors, one per textual inversion to load.
        kwargs:
            Recognised keys are `cache_dir`, `force_download`, `proxies`, `local_files_only`, `token`, `revision`,
            `subfolder`, `weight_name` and `use_safetensors`. Any other key is ignored.

    Returns:
        `list`: The loaded (or passed-through) state dicts, in input order.
    """
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    # Falling back to pickle is only allowed when the caller did not state a preference.
    allow_pickle = False
    if use_safetensors is None:
        use_safetensors = True
        allow_pickle = True

    user_agent = {
        "file_type": "text_inversion",
        "framework": "pytorch",
    }
    # Arguments shared by both `_get_model_file` lookups below.
    hub_kwargs = {
        "cache_dir": cache_dir,
        "force_download": force_download,
        "proxies": proxies,
        "local_files_only": local_files_only,
        "token": token,
        "revision": revision,
        "subfolder": subfolder,
        "user_agent": user_agent,
    }
    try_safetensors = (use_safetensors and weight_name is None) or (
        weight_name is not None and weight_name.endswith(".safetensors")
    )

    state_dicts = []
    for pretrained_model_name_or_path in pretrained_model_name_or_paths:
        if isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
            # Already loaded by the caller.
            state_dicts.append(pretrained_model_name_or_path)
            continue

        # 3.1. Load textual inversion file
        model_file = None

        # Let's first try to load .safetensors weights
        if try_safetensors:
            try:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
                    **hub_kwargs,
                )
                state_dict = safetensors.torch.load_file(model_file, device="cpu")
            except Exception:
                if not allow_pickle:
                    raise
                # Deliberate best effort: fall through to the pickle-based file.
                model_file = None

        if model_file is None:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=weight_name or TEXT_INVERSION_NAME,
                **hub_kwargs,
            )
            state_dict = load_state_dict(model_file)

        state_dicts.append(state_dict)

    return state_dicts


class TextualInversionLoaderMixin:
    r"""
    Load Textual Inversion tokens and embeddings to the tokenizer and text encoder.
    """

    def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):  # noqa: F821
        r"""
        Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
        be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
        inversion token or if the textual inversion token is a single vector, the input prompt is returned.

        Parameters:
            prompt (`str` or list of `str`):
                The prompt or prompts to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str` or list of `str`: The converted prompt
        """
        if not isinstance(prompt, list):
            return self._maybe_convert_prompt(prompt, tokenizer)
        return [self._maybe_convert_prompt(p, tokenizer) for p in prompt]

    def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):  # noqa: F821
        r"""
        Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
        to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
        is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
        inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.

        Parameters:
            prompt (`str`):
                The prompt to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str`: The converted prompt
        """
        tokens = tokenizer.tokenize(prompt)
        unique_tokens = set(tokens)
        for token in unique_tokens:
            if token in tokenizer.added_tokens_encoder:
                # Multi-vector embeddings are registered as `token`, `token_1`, `token_2`, ...
                replacement = token
                i = 1
                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
                    replacement += f" {token}_{i}"
                    i += 1

                prompt = prompt.replace(token, replacement)

        return prompt

    def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens):
        r"""
        Validate the normalized inputs of `load_textual_inversion`.

        Raises:
            `ValueError`: If `tokenizer` or `text_encoder` is `None`, if more than one model is passed and the number
            of models differs from the number of tokens, or if the non-`None` tokens contain duplicates.
        """
        if tokenizer is None:
            raise ValueError(
                f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
                f" `{self.load_textual_inversion.__name__}`"
            )

        if text_encoder is None:
            raise ValueError(
                f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
                f" `{self.load_textual_inversion.__name__}`"
            )

        if len(pretrained_model_name_or_paths) > 1 and len(pretrained_model_name_or_paths) != len(tokens):
            raise ValueError(
                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} "
                f"Make sure both lists have the same length."
            )

        valid_tokens = [t for t in tokens if t is not None]
        if len(set(valid_tokens)) < len(valid_tokens):
            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")

    @staticmethod
    def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
        r"""
        Pair every state dict with the token it should be registered under.

        Three layouts are recognised: a bare `torch.Tensor` (a token must then be supplied), a dict with a single
        entry (token -> embedding), and a dict with `name` and `string_to_param` keys.

        Returns:
            `tuple`: `(all_tokens, all_embeddings)`, two lists of equal length.

        Raises:
            `ValueError`: If a tensor is passed without a token, if the state dict layout is not recognised, or if
            the resulting token is already in the tokenizer vocabulary.
        """
        all_tokens = []
        all_embeddings = []
        # The tokenizer is not modified in this method, so one vocabulary snapshot is enough.
        vocab = tokenizer.get_vocab()
        for state_dict, token in zip(state_dicts, tokens):
            if isinstance(state_dict, torch.Tensor):
                if token is None:
                    raise ValueError(
                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
                    )
                loaded_token = token
                embedding = state_dict
            elif len(state_dict) == 1:
                # diffusers
                loaded_token, embedding = next(iter(state_dict.items()))
            elif "string_to_param" in state_dict:
                # A1111
                loaded_token = state_dict["name"]
                embedding = state_dict["string_to_param"]["*"]
            else:
                raise ValueError(
                    f"Loaded state dictionary is incorrect: {state_dict}. \n\n"
                    "Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
                    " input key."
                )

            if token is not None and loaded_token != token:
                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
            else:
                token = loaded_token

            if token in vocab:
                raise ValueError(
                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
                )

            all_tokens.append(token)
            all_embeddings.append(embedding)

        return all_tokens, all_embeddings

    @staticmethod
    def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer):
        r"""
        Expand multi-vector embeddings into one `(token, vector)` pair per vector.

        An embedding with more than one row is split into `token`, `token_1`, ..., `token_{n-1}`; any other embedding
        yields a single pair (its first row when it is 2-D).

        Returns:
            `tuple`: `(all_tokens, all_embeddings)`, two flat lists of equal length.

        Raises:
            `ValueError`: If `f"{token}_1"` is already in the tokenizer vocabulary.
        """
        all_tokens = []
        all_embeddings = []
        # The tokenizer is not modified in this method, so one vocabulary snapshot is enough.
        vocab = tokenizer.get_vocab()

        for embedding, token in zip(embeddings, tokens):
            if f"{token}_1" in vocab:
                multi_vector_tokens = [token]
                i = 1
                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
                    multi_vector_tokens.append(f"{token}_{i}")
                    i += 1

                raise ValueError(
                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
                )

            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
            if is_multi_vector:
                all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
                all_embeddings += list(embedding)
            else:
                all_tokens += [token]
                all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding]

        return all_tokens, all_embeddings

    @validate_hf_hub_args
    def load_textual_inversion(
        self,
        pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
        token: Optional[Union[str, List[str]]] = None,
        tokenizer: Optional["PreTrainedTokenizer"] = None,  # noqa: F821
        text_encoder: Optional["PreTrainedModel"] = None,  # noqa: F821
        **kwargs,
    ):
        r"""
        Load Textual Inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
        Automatic1111 formats are supported).

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
                Can be either one of the following or a list of them:

                    - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
                      pretrained model hosted on the Hub.
                    - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
                      inversion weights.
                    - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            token (`str` or `List[str]`, *optional*):
                Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
                list, then `token` must also be a list of equal length.
            text_encoder ([`~transformers.CLIPTextModel`], *optional*):
                Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
                If not specified, function will take self.text_encoder.
            tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
                A `CLIPTokenizer` to tokenize text. If not specified, function will take self.tokenizer.
            weight_name (`str`, *optional*):
                Name of a custom weight file. This should be used when:

                    - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
                      name such as `text_inv.bin`.
                    - The saved textual inversion file is in the Automatic1111 format.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.

            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used. Note that `token` is also the name of the
                explicit parameter documented above, so a `token=` keyword is always bound to that parameter and never
                reaches `**kwargs`.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.

        Example:

        To load a Textual Inversion embedding vector in 🤗 Diffusers format:

        ```py
        from diffusers import StableDiffusionPipeline
        import torch

        model_id = "runwayml/stable-diffusion-v1-5"
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

        pipe.load_textual_inversion("sd-concepts-library/cat-toy")

        prompt = "A backpack"

        image = pipe(prompt, num_inference_steps=50).images[0]
        image.save("cat-backpack.png")
        ```

        To load a Textual Inversion embedding vector in Automatic1111 format, make sure to download the vector first
        (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
        locally:

        ```py
        from diffusers import StableDiffusionPipeline
        import torch

        model_id = "runwayml/stable-diffusion-v1-5"
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

        pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")

        prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."

        image = pipe(prompt, num_inference_steps=50).images[0]
        image.save("character.png")
        ```

        """
        # 1. Set correct tokenizer and text encoder
        tokenizer = tokenizer or getattr(self, "tokenizer", None)
        text_encoder = text_encoder or getattr(self, "text_encoder", None)

        # 2. Normalize inputs
        pretrained_model_name_or_paths = (
            [pretrained_model_name_or_path]
            if not isinstance(pretrained_model_name_or_path, list)
            else pretrained_model_name_or_path
        )
        tokens = [token] if not isinstance(token, list) else token
        if tokens[0] is None:
            tokens = tokens * len(pretrained_model_name_or_paths)

        # 3. Check inputs
        self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)

        # 4. Load state dicts of textual embeddings
        state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)

        # 4.1 Handle the special case when state_dict is a tensor that contains n embeddings for n tokens
        if len(tokens) > 1 and len(state_dicts) == 1:
            if isinstance(state_dicts[0], torch.Tensor):
                state_dicts = list(state_dicts[0])
                if len(tokens) != len(state_dicts):
                    raise ValueError(
                        f"You have passed a state_dict contains {len(state_dicts)} embeddings, and list of tokens of length {len(tokens)} "
                        f"Make sure both have the same length."
                    )

        # 4.2 Retrieve tokens and embeddings
        tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)

        # 5. Extend tokens and embeddings for multi vector
        tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)

        # 6. Make sure all embeddings have the correct size
        expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
        mismatched_dims = sorted({emb.shape[-1] for emb in embeddings if emb.shape[-1] != expected_emb_dim})
        if mismatched_dims:
            raise ValueError(
                "Loaded embeddings are of incorrect shape. Expected each textual inversion embedding "
                f"to be of shape {expected_emb_dim}, but are {mismatched_dims} "
            )

        # 7. Now we can be sure that loading the embedding matrix works
        # < Unsafe code:

        # 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
        if self.hf_device_map is None:
            for _, component in self.components.items():
                if not (isinstance(component, nn.Module) and hasattr(component, "_hf_hook")):
                    continue

                hook = component._hf_hook
                # Accumulate across components: once an offload mode has been detected, a later component with a
                # different hook must not reset it, otherwise offloading is not re-enabled in step 7.5.
                if not is_model_cpu_offload:
                    is_model_cpu_offload = isinstance(hook, CpuOffload)
                if not is_sequential_cpu_offload:
                    is_sequential_cpu_offload = isinstance(hook, AlignDevicesHook) or (
                        hasattr(hook, "hooks") and isinstance(hook.hooks[0], AlignDevicesHook)
                    )
                logger.info(
                    "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
                )
                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

        # 7.2 save expected device and dtype
        device = text_encoder.device
        dtype = text_encoder.dtype

        # 7.3 Increase token embedding matrix
        text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
        input_embeddings = text_encoder.get_input_embeddings().weight

        # 7.4 Load token and embedding
        for token, embedding in zip(tokens, embeddings):
            # add tokens and get ids
            tokenizer.add_tokens(token)
            token_id = tokenizer.convert_tokens_to_ids(token)
            input_embeddings.data[token_id] = embedding
            logger.info(f"Loaded textual inversion embedding for {token}.")

        # NOTE(review): `Tensor.to` is not in-place and its result is discarded here, so this statement does not
        # change `input_embeddings`. Kept to preserve existing behavior - confirm whether a cast was intended.
        input_embeddings.to(dtype=dtype, device=device)

        # 7.5 Offload the model again
        if is_model_cpu_offload:
            self.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            self.enable_sequential_cpu_offload()

        # / Unsafe Code >

    def unload_textual_inversion(
        self,
        tokens: Optional[Union[str, List[str]]] = None,
        tokenizer: Optional["PreTrainedTokenizer"] = None,
        text_encoder: Optional["PreTrainedModel"] = None,
    ):
        r"""
        Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`]

        Parameters:
            tokens (`str` or `List[str]`, *optional*):
                Content of the added token(s) to remove. If empty or `None`, every non-special added token is
                removed.
            tokenizer (`PreTrainedTokenizer`, *optional*):
                Tokenizer to remove the tokens from. If not specified, function will take self.tokenizer.
            text_encoder (`PreTrainedModel`, *optional*):
                Text encoder whose input embedding is rebuilt without the removed rows. If not specified, function
                will take self.text_encoder.

        Raises:
            `ValueError`: If `tokens` is given but none of them is a non-special added token of the tokenizer.

        Example:
        ```py
        from diffusers import AutoPipelineForText2Image
        from huggingface_hub import hf_hub_download
        from safetensors.torch import load_file
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")

        # Example 1
        pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
        pipeline.load_textual_inversion("sd-concepts-library/moeb-style")

        # Remove all token embeddings
        pipeline.unload_textual_inversion()

        # Example 2
        pipeline.load_textual_inversion("sd-concepts-library/moeb-style")
        pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")

        # Remove just one token
        pipeline.unload_textual_inversion("<moe-bius>")

        # Example 3: unload from SDXL
        pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
        embedding_path = hf_hub_download(
            repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model"
        )

        # load embeddings to the text encoders
        state_dict = load_file(embedding_path)

        # load embeddings of text_encoder 1 (CLIP ViT-L/14)
        pipeline.load_textual_inversion(
            state_dict["clip_l"],
            token=["<s0>", "<s1>"],
            text_encoder=pipeline.text_encoder,
            tokenizer=pipeline.tokenizer,
        )
        # load embeddings of text_encoder 2 (CLIP ViT-G/14)
        pipeline.load_textual_inversion(
            state_dict["clip_g"],
            token=["<s0>", "<s1>"],
            text_encoder=pipeline.text_encoder_2,
            tokenizer=pipeline.tokenizer_2,
        )

        # Unload explicitly from both text encoders and tokenizers
        pipeline.unload_textual_inversion(
            tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
        )
        pipeline.unload_textual_inversion(
            tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
        )
        ```
        """

        tokenizer = tokenizer or getattr(self, "tokenizer", None)
        text_encoder = text_encoder or getattr(self, "text_encoder", None)

        # An empty or missing `tokens` argument means "remove every non-special added token".
        remove_all = not tokens
        if not remove_all and isinstance(tokens, str):
            tokens = [tokens]
        requested = None if remove_all else set(tokens)

        # Get textual inversion tokens and ids. Each id is stored together with its own content so that the
        # tokenizer's encoder and decoder are always cleaned up for the same token, whatever order (or unknown
        # names) the caller passed.
        removed = {}
        non_special_ids = []
        last_special_token_id = None
        for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
            if added_token.special:
                last_special_token_id = added_token_id
                continue
            non_special_ids.append(added_token_id)
            if remove_all or added_token.content in requested:
                removed[added_token_id] = added_token.content

        if not remove_all and not removed:
            raise ValueError("No tokens to remove found")

        if last_special_token_id is None:
            # No special added token to anchor on: use the id right below the first added token instead.
            if not non_special_ids:
                return
            last_special_token_id = min(non_special_ids) - 1

        # Delete from tokenizer
        for token_id, token_to_remove in removed.items():
            del tokenizer._added_tokens_decoder[token_id]
            del tokenizer._added_tokens_encoder[token_to_remove]

        # Make all token ids sequential in tokenizer. The slot counter advances for every surviving token, moved
        # or not, so a token that already sits in its slot is never overwritten by a later one.
        next_id = last_special_token_id + 1
        for token_id in sorted(tokenizer.added_tokens_decoder):
            if token_id <= last_special_token_id:
                continue
            if token_id != next_id:
                token = tokenizer._added_tokens_decoder.pop(token_id)
                tokenizer._added_tokens_decoder[next_id] = token
                tokenizer._added_tokens_encoder[token.content] = next_id
            next_id += 1
        tokenizer._update_trie()

        # Delete from text encoder: keep every row up to the anchor id, then the rows that were not removed.
        text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
        temp_text_embedding_weights = text_encoder.get_input_embeddings().weight
        text_embedding_weights = temp_text_embedding_weights[: last_special_token_id + 1]
        to_append = [
            temp_text_embedding_weights[i].unsqueeze(0)
            for i in range(last_special_token_id + 1, temp_text_embedding_weights.shape[0])
            if i not in removed
        ]
        if len(to_append) > 0:
            to_append = torch.cat(to_append, dim=0)
            text_embedding_weights = torch.cat([text_embedding_weights, to_append], dim=0)
        text_embeddings_filtered = nn.Embedding(text_embedding_weights.shape[0], text_embedding_dim)
        text_embeddings_filtered.weight.data = text_embedding_weights
        text_encoder.set_input_embeddings(text_embeddings_filtered)


# ================================================
# FILE: src/diffusers/loaders/unet.py
# ================================================
#
Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from collections import defaultdict from contextlib import nullcontext from pathlib import Path from typing import Callable, Dict, Union import safetensors import torch import torch.nn.functional as F from huggingface_hub.utils import validate_hf_hub_args from torch import nn from ..models.embeddings import ( ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterFaceIDPlusImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, ) from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict from ..utils import ( USE_PEFT_BACKEND, _get_model_file, convert_unet_state_dict_to_peft, get_adapter_name, get_peft_kwargs, is_accelerate_available, is_peft_version, is_torch_version, logging, ) from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME from .utils import AttnProcsLayers if is_accelerate_available(): from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module logger = logging.get_logger(__name__) CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin" CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors" class UNet2DConditionLoadersMixin: """ Load LoRA layers into a [`UNet2DCondtionModel`]. 
""" text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME @validate_hf_hub_args def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): r""" Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be defined in [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py) and be a `torch.nn.Module` class. Currently supported: LoRA, Custom Diffusion. For LoRA, one must install `peft`: `pip install -U peft`. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): Can be either: - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub. - A path to a directory (for example `./my_model_directory`) containing the model weights saved with [`ModelMixin.save_pretrained`]. - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. 
If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. network_alphas (`Dict[str, float]`): The value of the network alpha used for stable learning and preventing underflow. This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). adapter_name (`str`, *optional*, defaults to None): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. 
Example: ```py from diffusers import AutoPipelineForText2Image import torch pipeline = AutoPipelineForText2Image.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 ).to("cuda") pipeline.unet.load_attn_procs( "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" ) ``` """ cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) adapter_name = kwargs.pop("adapter_name", None) _pipeline = kwargs.pop("_pipeline", None) network_alphas = kwargs.pop("network_alphas", None) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True user_agent = { "file_type": "attn_procs_weights", "framework": "pytorch", } model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): # Let's first try to load .safetensors weights if (use_safetensors and weight_name is None) or ( weight_name is not None and weight_name.endswith(".safetensors") ): try: model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, cache_dir=cache_dir, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") except IOError as e: if not allow_pickle: raise e # try loading non-safetensors weights pass if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name or LORA_WEIGHT_NAME, cache_dir=cache_dir, force_download=force_download, proxies=proxies, 
local_files_only=local_files_only, token=token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) state_dict = load_state_dict(model_file) else: state_dict = pretrained_model_name_or_path_or_dict is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) is_model_cpu_offload = False is_sequential_cpu_offload = False if is_custom_diffusion: attn_processors = self._process_custom_diffusion(state_dict=state_dict) elif is_lora: is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora( state_dict=state_dict, unet_identifier_key=self.unet_name, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline, ) else: raise ValueError( f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training." ) # def _process_custom_diffusion(self, state_dict): from ..models.attention_processor import CustomDiffusionAttnProcessor attn_processors = {} custom_diffusion_grouped_dict = defaultdict(dict) for key, value in state_dict.items(): if len(value) == 0: custom_diffusion_grouped_dict[key] = {} else: if "to_out" in key: attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) else: attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:]) custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value for key, value_dict in custom_diffusion_grouped_dict.items(): if len(value_dict) == 0: attn_processors[key] = CustomDiffusionAttnProcessor( train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None ) else: cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1] hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0] train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False attn_processors[key] = CustomDiffusionAttnProcessor( train_kv=True, train_q_out=train_q_out, 
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, ) attn_processors[key].load_state_dict(value_dict) return attn_processors def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline): # This method does the following things: # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy # format. For legacy format no filtering is applied. # 2. Converts the `state_dict` to the `peft` compatible format. # 3. Creates a `LoraConfig` and then injects the converted `state_dict` into the UNet per the # `LoraConfig` specs. # 4. It also reports if the underlying `_pipeline` has any kind of offloading inside of it. if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict keys = list(state_dict.keys()) unet_keys = [k for k in keys if k.startswith(unet_identifier_key)] unet_state_dict = { k.replace(f"{unet_identifier_key}.", ""): v for k, v in state_dict.items() if k in unet_keys } if network_alphas is not None: alpha_keys = [k for k in network_alphas.keys() if k.startswith(unet_identifier_key)] network_alphas = { k.replace(f"{unet_identifier_key}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } is_model_cpu_offload = False is_sequential_cpu_offload = False state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict if len(state_dict_to_be_used) > 0: if adapter_name in getattr(self, "peft_config", {}): raise ValueError( f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name." ) state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used) if network_alphas is not None: # The alphas state dict have the same structure as Unet, thus we convert it to peft format using # `convert_unet_state_dict_to_peft` method. 
network_alphas = convert_unet_state_dict_to_peft(network_alphas) rank = {} for key, val in state_dict.items(): if "lora_B" in key: rank[key] = val.shape[1] lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: if is_peft_version("<", "0.9.0"): raise ValueError( "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." ) else: if is_peft_version("<", "0.9.0"): lora_config_kwargs.pop("use_dora") lora_config = LoraConfig(**lora_config_kwargs) # adapter_name if adapter_name is None: adapter_name = get_adapter_name(self) # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) inject_adapter_in_model(lora_config, self, adapter_name=adapter_name) incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name) if incompatible_keys is not None: # check only for unexpected keys unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) if unexpected_keys: logger.warning( f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " f" {unexpected_keys}. " ) return is_model_cpu_offload, is_sequential_cpu_offload @classmethod # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading def _optionally_disable_offloading(cls, _pipeline): """ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. Args: _pipeline (`DiffusionPipeline`): The pipeline to disable offloading for. Returns: tuple: A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. 
def save_attn_procs(
    self,
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
    weight_name: str = None,
    save_function: Callable = None,
    safe_serialization: bool = True,
    **kwargs,
):
    r"""
    Write the attention processor layers to `save_directory` so they can later be restored with the
    [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.

    Custom Diffusion processors, when present, are serialized through `_get_custom_diffusion_state_dict()`;
    otherwise the PEFT model state dict of this model is saved.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Target directory (created if it doesn't exist). If the path is an existing file, an error is
            logged and nothing is saved.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the calling process is the main process. Useful during distributed training, where this
            function is called on all processes; set `is_main_process=True` only on the main process to
            avoid race conditions. The flag is accepted for API compatibility and is not read by this
            method's body.
        weight_name (`str`, *optional*):
            File name of the saved weights. When omitted, a default is chosen from `safe_serialization` and
            from whether Custom Diffusion processors are being saved.
        save_function (`Callable`):
            Called as `save_function(state_dict, path)` to persist the weights. Useful during distributed
            training when `torch.save` has to be replaced. When omitted, `safetensors.torch.save_file` is
            used if `safe_serialization` is set, `torch.save` otherwise.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Whether to save the model using `safetensors` or with `pickle`.

    Example:

    ```py
    import torch
    from diffusers import DiffusionPipeline

    pipeline = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
    ).to("cuda")
    pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
    pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
    ```
    """
    from ..models.attention_processor import (
        CustomDiffusionAttnProcessor,
        CustomDiffusionAttnProcessor2_0,
        CustomDiffusionXFormersAttnProcessor,
    )

    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    custom_diffusion_classes = (
        CustomDiffusionAttnProcessor,
        CustomDiffusionAttnProcessor2_0,
        CustomDiffusionXFormersAttnProcessor,
    )
    is_custom_diffusion = any(
        isinstance(processor, custom_diffusion_classes) for processor in self.attn_processors.values()
    )
    use_default_saver = save_function is None

    if not is_custom_diffusion:
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")

        from peft.utils import get_peft_model_state_dict

        state_dict = get_peft_model_state_dict(self)
    else:
        state_dict = self._get_custom_diffusion_state_dict()
        if use_default_saver and safe_serialization:
            # safetensors does not support saving dicts with non-tensor values
            empty_state_dict = {
                name: entry for name, entry in state_dict.items() if not isinstance(entry, torch.Tensor)
            }
            if empty_state_dict:
                logger.warning(
                    f"Safetensors does not support saving dicts with non-tensor values. "
                    f"The following keys will be ignored: {empty_state_dict.keys()}"
                )
            state_dict = {name: entry for name, entry in state_dict.items() if name not in empty_state_dict}

    if use_default_saver:
        if safe_serialization:

            def _save_as_safetensors(weights, filename):
                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

            save_function = _save_as_safetensors
        else:
            save_function = torch.save

    os.makedirs(save_directory, exist_ok=True)

    if weight_name is None:
        # Default file name depends on both the serialization format and the kind of processors saved.
        if is_custom_diffusion:
            weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if safe_serialization else CUSTOM_DIFFUSION_WEIGHT_NAME
        else:
            weight_name = LORA_WEIGHT_NAME_SAFE if safe_serialization else LORA_WEIGHT_NAME

    # Save the model
    save_path = Path(save_directory, weight_name).as_posix()
    save_function(state_dict, save_path)
    logger.info(f"Model weights saved in {save_path}")
It is strongly recommended to install" " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" " install accelerate\n```\n." ) if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" " `low_cpu_mem_usage=False`." ) updated_state_dict = {} image_projection = None init_context = init_empty_weights if low_cpu_mem_usage else nullcontext if "proj.weight" in state_dict: # IP-Adapter num_image_text_embeds = 4 clip_embeddings_dim = state_dict["proj.weight"].shape[-1] cross_attention_dim = state_dict["proj.weight"].shape[0] // 4 with init_context(): image_projection = ImageProjection( cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=num_image_text_embeds, ) for key, value in state_dict.items(): diffusers_name = key.replace("proj", "image_embeds") updated_state_dict[diffusers_name] = value elif "proj.3.weight" in state_dict: # IP-Adapter Full clip_embeddings_dim = state_dict["proj.0.weight"].shape[0] cross_attention_dim = state_dict["proj.3.weight"].shape[0] with init_context(): image_projection = IPAdapterFullImageProjection( cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim ) for key, value in state_dict.items(): diffusers_name = key.replace("proj.0", "ff.net.0.proj") diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") diffusers_name = diffusers_name.replace("proj.3", "norm") updated_state_dict[diffusers_name] = value elif "perceiver_resampler.proj_in.weight" in state_dict: # IP-Adapter Face ID Plus id_embeddings_dim = state_dict["proj.0.weight"].shape[1] embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0] hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1] output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0] heads = 
state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64 with init_context(): image_projection = IPAdapterFaceIDPlusImageProjection( embed_dims=embed_dims, output_dims=output_dims, hidden_dims=hidden_dims, heads=heads, id_embeddings_dim=id_embeddings_dim, ) for key, value in state_dict.items(): diffusers_name = key.replace("perceiver_resampler.", "") diffusers_name = diffusers_name.replace("0.to", "attn.to") diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.") diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight") diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight") diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.") diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight") diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight") diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.") diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight") diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight") diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.") diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight") diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight") diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0") diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1") diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0") diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1") diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0") diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1") diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0") diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1") if "norm1" in diffusers_name: updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value elif "norm2" in 
diffusers_name: updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value elif "to_kv" in diffusers_name: v_chunk = value.chunk(2, dim=0) updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] elif "to_out" in diffusers_name: updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value elif "proj.0.weight" == diffusers_name: updated_state_dict["proj.net.0.proj.weight"] = value elif "proj.0.bias" == diffusers_name: updated_state_dict["proj.net.0.proj.bias"] = value elif "proj.2.weight" == diffusers_name: updated_state_dict["proj.net.2.weight"] = value elif "proj.2.bias" == diffusers_name: updated_state_dict["proj.net.2.bias"] = value else: updated_state_dict[diffusers_name] = value elif "norm.weight" in state_dict: # IP-Adapter Face ID id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] multiplier = id_embeddings_dim_out // id_embeddings_dim_in norm_layer = "norm.weight" cross_attention_dim = state_dict[norm_layer].shape[0] num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim with init_context(): image_projection = IPAdapterFaceIDImageProjection( cross_attention_dim=cross_attention_dim, image_embed_dim=id_embeddings_dim_in, mult=multiplier, num_tokens=num_tokens, ) for key, value in state_dict.items(): diffusers_name = key.replace("proj.0", "ff.net.0.proj") diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") updated_state_dict[diffusers_name] = value else: # IP-Adapter Plus num_image_text_embeds = state_dict["latents"].shape[1] embed_dims = state_dict["proj_in.weight"].shape[1] output_dims = state_dict["proj_out.weight"].shape[0] hidden_dims = state_dict["latents"].shape[2] attn_key_present = any("attn" in k for k in state_dict) heads = ( state_dict["layers.0.attn.to_q.weight"].shape[0] // 64 if attn_key_present else 
state_dict["layers.0.0.to_q.weight"].shape[0] // 64 ) with init_context(): image_projection = IPAdapterPlusImageProjection( embed_dims=embed_dims, output_dims=output_dims, hidden_dims=hidden_dims, heads=heads, num_queries=num_image_text_embeds, ) for key, value in state_dict.items(): diffusers_name = key.replace("0.to", "2.to") diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0") diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1") diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0") diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1") diffusers_name = diffusers_name.replace("2.0.norm1", "2.ln0") diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1") diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0") diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1") if "to_kv" in diffusers_name: parts = diffusers_name.split(".") parts[2] = "attn" diffusers_name = ".".join(parts) v_chunk = value.chunk(2, dim=0) updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] elif "to_q" in diffusers_name: parts = diffusers_name.split(".") parts[2] = "attn" diffusers_name = ".".join(parts) updated_state_dict[diffusers_name] = value elif "to_out" in diffusers_name: parts = diffusers_name.split(".") parts[2] = "attn" diffusers_name = ".".join(parts) updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value else: diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0") diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj") diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2") diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0") diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj") diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2") diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0") diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj") 
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
    """
    Build the attention processors needed for IP-Adapter and load the IP-Adapter key/value weights.

    Processors whose name ends with `attn1.processor`, or whose name contains `motion_modules`, are
    re-created from their current class with no arguments. Every other processor is replaced by an
    IP-Adapter attention processor that receives one `to_k_ip` / `to_v_ip` weight pair per entry of
    `state_dicts`.

    Args:
        state_dicts: Sequence of IP-Adapter checkpoints. Each entry is indexed with the `"image_proj"` and
            `"ip_adapter"` keys; the `"ip_adapter"` weights are looked up as `"{key_id}.to_k_ip.weight"` and
            `"{key_id}.to_v_ip.weight"`, with `key_id` starting at 1 and growing by 2 per IP-Adapter processor.
        low_cpu_mem_usage: When truthy and `accelerate` is available, processors are created inside
            `init_empty_weights()` and filled through `load_model_dict_into_meta`. Falls back to `False`
            with a warning when `accelerate` is missing.

    Returns:
        A dict mapping every name in `self.attn_processors` to its new processor.

    Raises:
        NotImplementedError: If low-memory initialization is requested with torch < 1.9.0.
    """
    from ..models.attention_processor import (
        IPAdapterAttnProcessor,
        IPAdapterAttnProcessor2_0,
    )

    if low_cpu_mem_usage:
        if is_accelerate_available():
            from accelerate import init_empty_weights

        else:
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

    if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
        raise NotImplementedError(
            "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
            " `low_cpu_mem_usage=False`."
        )

    def _count_image_tokens(image_proj):
        # Number of image tokens for one checkpoint, keyed off which projection weights it contains.
        if "proj.weight" in image_proj:
            # IP-Adapter
            return 4
        if "proj.3.weight" in image_proj:
            # IP-Adapter Full Face
            return 257  # 256 CLIP tokens + 1 CLS token
        if "perceiver_resampler.proj_in.weight" in image_proj:
            # IP-Adapter Face ID Plus
            return 4
        if "norm.weight" in image_proj:
            # IP-Adapter Face ID
            return 4
        # IP-Adapter Plus
        return image_proj["latents"].shape[1]

    # set ip-adapter cross-attention processors & load state_dict
    attn_procs = {}
    key_id = 1
    init_context = init_empty_weights if low_cpu_mem_usage else nullcontext

    # Loop-invariant values, resolved once instead of once per processor.
    ip_adapter_processor_class = (
        IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
    )
    current_processors = self.attn_processors
    # Filled in on the first IP-Adapter processor so malformed checkpoints only raise when actually needed.
    num_image_text_embeds = None

    for name in current_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
        # NOTE: a name matching none of these prefixes leaves `hidden_size` at the value of the previous
        # iteration (or unbound on the first one), exactly as before.
        if name.startswith("mid_block"):
            hidden_size = self.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(self.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = self.config.block_out_channels[block_id]

        if cross_attention_dim is None or "motion_modules" in name:
            attn_processor_class = current_processors[name].__class__
            attn_procs[name] = attn_processor_class()
            continue

        if num_image_text_embeds is None:
            num_image_text_embeds = [_count_image_tokens(state_dict["image_proj"]) for state_dict in state_dicts]

        with init_context():
            attn_procs[name] = ip_adapter_processor_class(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=1.0,
                # Each processor gets its own list, as it did when the list was rebuilt per processor.
                num_tokens=list(num_image_text_embeds),
            )

        value_dict = {}
        for i, state_dict in enumerate(state_dicts):
            ip_adapter_weights = state_dict["ip_adapter"]
            value_dict[f"to_k_ip.{i}.weight"] = ip_adapter_weights[f"{key_id}.to_k_ip.weight"]
            value_dict[f"to_v_ip.{i}.weight"] = ip_adapter_weights[f"{key_id}.to_v_ip.weight"]

        if not low_cpu_mem_usage:
            attn_procs[name].load_state_dict(value_dict)
        else:
            # Device and dtype are taken from the first loaded weight.
            first_weight = next(iter(value_dict.values()))
            load_model_dict_into_meta(
                attn_procs[name], value_dict, device=first_weight.device, dtype=first_weight.dtype
            )

        key_id += 2

    return attn_procs
f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][ f"{key_id}.to_v_lora.down.weight" ] } ) lora_dicts[i].update( { f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][ f"{key_id}.to_out_lora.down.weight" ] } ) lora_dicts[i].update( {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]} ) lora_dicts[i].update( {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]} ) lora_dicts[i].update( {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]} ) lora_dicts[i].update( { f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][ f"{key_id}.to_out_lora.up.weight" ] } ) return lora_dicts ================================================ FILE: src/diffusers/loaders/unet_loader_utils.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy from typing import TYPE_CHECKING, Dict, List, Union from ..utils import logging if TYPE_CHECKING: # import here to avoid circular imports from ..models import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name def _translate_into_actual_layer_name(name): """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 
def _maybe_expand_lora_scales_for_one_adapter(
    scales: Union[float, Dict],
    blocks_with_transformer: Dict[str, List[int]],
    transformer_per_block: Dict[str, int],
    state_dict: Dict,
    default_scale: float = 1.0,
):
    """
    Expands the scales of one adapter into a flat dictionary keyed by actual layer name.

    Parameters:
        scales (`Union[float, Dict]`):
            Scales to expand. Anything that is not a dict is returned unchanged. A dict is deep-copied first,
            so the caller's object is never modified.
        blocks_with_transformer (`Dict[str, List[int]]`):
            Dict with exactly the keys 'up' and 'down', listing the indices of the blocks that have
            transformer layers.
        transformer_per_block (`Dict[str, int]`):
            Dict with exactly the keys 'up' and 'down', giving how many transformer layers each block has.
        state_dict (`Dict`):
            Mapping whose keys are module/parameter names. It is only used to check that every expanded
            layer name occurs inside at least one of its keys.
        default_scale (`float`, defaults to 1.0):
            Scale used for 'mid', 'up', 'down' or individual blocks missing from `scales`.

    Returns:
        `scales` itself when it is not a dict, otherwise a dict mapping actual layer names to scales.

    Raises:
        ValueError: If `blocks_with_transformer` or `transformer_per_block` do not have exactly the keys
            'down' and 'up', if a list of scales has the wrong length, or if an expanded layer name is not
            found in any key of `state_dict`.

    E.g. turns
    ```python
    scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
    blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
    transformer_per_block = {"down": 2, "up": 3}
    ```
    into
    ```python
    {
        "mid_block.attentions.0": 3,
        "up_blocks.0.attentions.0": 4,
        "up_blocks.0.attentions.1": 4,
        "up_blocks.0.attentions.2": 4,
        "up_blocks.1.attentions.0": 5,
        "up_blocks.1.attentions.1": 6,
        "up_blocks.1.attentions.2": 7,
        "down_blocks.1.attentions.0": 2,
        "down_blocks.1.attentions.1": 2,
        "down_blocks.2.attentions.0": 2,
        "down_blocks.2.attentions.1": 2,
    }
    ```
    """
    if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
        raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")

    if sorted(transformer_per_block.keys()) != ["down", "up"]:
        raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")

    if not isinstance(scales, dict):
        # don't expand if scales is a single number
        return scales

    scales = copy.deepcopy(scales)

    if "mid" not in scales:
        scales["mid"] = default_scale
    elif isinstance(scales["mid"], list):
        if len(scales["mid"]) == 1:
            scales["mid"] = scales["mid"][0]
        else:
            raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.")

    for updown in ["up", "down"]:
        block_ids = blocks_with_transformer[updown]
        num_transformers = transformer_per_block[updown]

        if updown not in scales:
            scales[updown] = default_scale

        # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
        if not isinstance(scales[updown], dict):
            scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in block_ids}

        block_scales = scales[updown]

        # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
        for i in block_ids:
            block = f"block_{i}"
            # set not assigned blocks to default scale
            if block not in block_scales:
                block_scales[block] = default_scale
            if not isinstance(block_scales[block], list):
                block_scales[block] = [block_scales[block] for _ in range(num_transformers)]
            elif len(block_scales[block]) == 1:
                # a one-element list is repeated for every transformer layer of the block
                block_scales[block] = block_scales[block] * num_transformers
            elif len(block_scales[block]) != num_transformers:
                raise ValueError(
                    f"Expected {num_transformers} scales for {updown}.{block}, got {len(block_scales[block])}."
                )

        # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
        for i in block_ids:
            block = f"block_{i}"
            for tf_idx, value in enumerate(block_scales[block]):
                scales[f"{updown}.{block}.{tf_idx}"] = value

        # Only blocks listed in `blocks_with_transformer` were flattened above; any other entry of the
        # nested dict is discarded together with it.
        del scales[updown]

    module_names = list(state_dict.keys())
    expanded_scales = {}
    for layer, weight in scales.items():
        # Translate once per layer instead of once per state-dict key.
        actual_name = _translate_into_actual_layer_name(layer)
        if not any(actual_name in module for module in module_names):
            raise ValueError(
                f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
            )
        expanded_scales[actual_name] = weight

    return expanded_scales
class AttnProcsLayers(torch.nn.Module):
    """
    Holds a collection of named modules in a `ModuleList` while exposing their original names in
    `state_dict()` / `load_state_dict()`.

    The values of the mapping passed to the constructor are handed to `torch.nn.ModuleList`; its keys are
    the external names. Two hooks translate between the internal `layers.<index>` prefix and the external
    name: one rewrites keys when a state dict is produced, the other rewrites keys back before a state dict
    is loaded. On load, the external name is recovered by cutting each key after the first marker from
    `split_keys` that it contains.
    """

    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
        names = list(state_dict.keys())
        self.layers = torch.nn.ModuleList(state_dict.values())
        # index -> external name, and external name -> index
        self.mapping = {index: name for index, name in enumerate(names)}
        self.rev_mapping = {name: index for index, name in enumerate(names)}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        def external_prefix(key, state_dict):
            # Everything up to and including the first marker found in `key`.
            marker = next((candidate for candidate in self.split_keys if candidate in key), None)
            if marker is None:
                raise ValueError(
                    f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
                )
            return key.split(marker)[0] + marker

        def to_external_names(module, state_dict, *args, **kwargs):
            # state_dict() hook: `layers.<index>...` -> `<external name>...`
            def rename(key):
                index = int(key.split(".")[1])  # element 0 is always "layers"
                return key.replace(f"layers.{index}", module.mapping[index])

            return {rename(key): value for key, value in state_dict.items()}

        def to_internal_names(module, state_dict, *args, **kwargs):
            # load_state_dict() pre-hook: `<external name>...` -> `layers.<index>...`, rewritten in place.
            for key in list(state_dict.keys()):
                prefix = external_prefix(key, state_dict)
                internal_key = key.replace(prefix, f"layers.{module.rev_mapping[prefix]}")
                state_dict[internal_key] = state_dict[key]
                del state_dict[key]

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        self._register_state_dict_hook(to_external_names)
        self._register_load_state_dict_pre_hook(to_internal_names, with_module=True)
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import TYPE_CHECKING from ..utils import ( DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available, ) _import_structure = {} if is_torch_available(): _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["controlnet"] = ["ControlNetModel"] _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"] _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"] _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"] _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.auraflow_transformer_2d"] = 
["AuraFlowTransformer2DModel"] _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] _import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"] _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"] _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"] _import_structure["transformers.prior_transformer"] = ["PriorTransformer"] _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] _import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"] _import_structure["unets.unet_i2vgen_xl"] = ["I2VGenXLUNet"] _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"] _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"] _import_structure["unets.uvit_2d"] = ["UVit2DModel"] if is_flax_available(): _import_structure["controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["unets.unet_2d_condition_flax"] = 
["FlaxUNet2DConditionModel"] _import_structure["vae_flax"] = ["FlaxAutoencoderKL"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .adapter import MultiAdapter, T2IAdapter from .autoencoders import ( AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderKLCogVideoX, AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, ConsistencyDecoderVAE, VQModel, ) from .controlnet import ControlNetModel from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel from .controlnet_sparsectrl import SparseControlNetModel from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( AuraFlowTransformer2DModel, CogVideoXTransformer3DModel, DiTTransformer2DModel, DualTransformer2DModel, FluxTransformer2DModel, HunyuanDiT2DModel, LatteTransformer3DModel, LuminaNextDiT2DModel, PixArtTransformer2DModel, PriorTransformer, SD3Transformer2DModel, StableAudioDiTModel, T5FilmDecoder, Transformer2DModel, TransformerTemporalModel, ) from .unets import ( I2VGenXLUNet, Kandinsky3UNet, MotionAdapter, StableCascadeUNet, UNet1DModel, UNet2DConditionModel, UNet2DModel, UNet3DConditionModel, UNetMotionModel, UNetSpatioTemporalConditionModel, UVit2DModel, ) if is_flax_available(): from .controlnet_flax import FlaxControlNetModel from .unets import FlaxUNet2DConditionModel from .vae_flax import FlaxAutoencoderKL else: import sys sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) ================================================ FILE: src/diffusers/models/activations.py ================================================ # coding=utf-8 # Copyright 2024 HuggingFace Inc. 
# Registry of parameter-free activation modules, addressable by lower-case name.
# "swish" and "silu" name the same function but are kept as two separate module instances.
ACTIVATION_FUNCTIONS = {
    "swish": nn.SiLU(),
    "silu": nn.SiLU(),
    "mish": nn.Mish(),
    "gelu": nn.GELU(),
    "relu": nn.ReLU(),
}


def get_activation(act_fn: str) -> nn.Module:
    """Look up an activation module by (case-insensitive) name.

    Args:
        act_fn (str): Name of the activation function; it is lower-cased before the lookup.

    Returns:
        nn.Module: The module instance stored in `ACTIVATION_FUNCTIONS` (the same object on every call).

    Raises:
        ValueError: If the lower-cased name is not a key of `ACTIVATION_FUNCTIONS`.
    """
    key = act_fn.lower()
    activation = ACTIVATION_FUNCTIONS.get(key)
    if activation is None:
        raise ValueError(f"Unsupported activation function: {key}")
    return activation


class FP32SiLU(nn.Module):
    r"""
    SiLU activation that is evaluated in torch.float32 and cast back to the dtype of its input.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        upcast = inputs.float()
        activated = F.silu(upcast, inplace=False)
        return activated.to(inputs.dtype)


class GELU(nn.Module):
    r"""
    Linear projection followed by GELU; the tanh approximation is selected with `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply GELU with the configured approximation, routing mps tensors through float32."""
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16
            as_fp32 = gate.to(dtype=torch.float32)
            return F.gelu(as_fp32, approximate=self.approximate).to(dtype=gate.dtype)
        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        return self.gelu(self.proj(hidden_states))


class GEGLU(nn.Module):
    r"""
    A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.

    The input is projected to `2 * dim_out` channels; one half is the value, the other half is passed through
    GELU and used as a multiplicative gate.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply exact GELU, routing mps tensors through float32."""
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states, *args, **kwargs):
        # Extra positional arguments or a non-None `scale` keyword are accepted only to emit the deprecation notice.
        scale_was_passed = bool(args) or kwargs.get("scale", None) is not None
        if scale_was_passed:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        projected = self.proj(hidden_states)
        if is_torch_npu_available():
            # using torch_npu.npu_geglu can run faster and save memory on NPU.
            return torch_npu.npu_geglu(projected, dim=-1, approximate=1)[0]

        value, gate = projected.chunk(2, dim=-1)
        return value * self.gelu(gate)


class SwiGLU(nn.Module):
    r"""
    A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to
    `GEGLU` but gates with SiLU / Swish instead of GeLU.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
        self.activation = nn.SiLU()

    def forward(self, hidden_states):
        value, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return value * self.activation(gate)


class ApproximateGELU(nn.Module):
    r"""
    Linear projection followed by the sigmoid approximation of the Gaussian Error Linear Unit,
    `x * sigmoid(1.702 * x)`. For more details, see section 2 of this
    [paper](https://arxiv.org/abs/1606.08415).

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        gate = torch.sigmoid(1.702 * projected)
        return projected * gate
class MultiAdapter(ModelMixin):
    r"""
    MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
    user-assigned weighting.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the model (such as downloading or saving, etc.)

    Parameters:
        adapters (`List[T2IAdapter]`):
            A list of at least two `T2IAdapter` model instances that share the same `downscale_factor` and
            `total_downscale_factor`.
    """

    def __init__(self, adapters: List["T2IAdapter"]):
        super().__init__()

        # Validate the count first so an invalid list never gets registered as sub-modules.
        if len(adapters) == 0:
            raise ValueError("Expecting at least one adapter")

        if len(adapters) == 1:
            raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")

        self.num_adapter = len(adapters)
        self.adapters = nn.ModuleList(adapters)

        # The outputs from each adapter are added together with a weight.
        # This means that the change in dimensions from downsampling must
        # be the same for all adapters. Inductively, it also means the
        # downscale_factor and total_downscale_factor must be the same for all
        # adapters.
        first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
        first_adapter_downscale_factor = adapters[0].downscale_factor
        for idx in range(1, len(adapters)):
            if (
                adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor
                or adapters[idx].downscale_factor != first_adapter_downscale_factor
            ):
                raise ValueError(
                    f"Expecting all adapters to have the same downscaling behavior, but got:\n"
                    f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n"
                    f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n"
                    f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n"
                    f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}"
                )

        self.total_downscale_factor = first_adapter_total_downscale_factor
        self.downscale_factor = first_adapter_downscale_factor

    def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
        r"""
        Run every adapter on its own input and return the weighted sum of their feature lists.

        Args:
            xs (`torch.Tensor`):
                Inputs for the adapters. `xs` is iterated in lockstep with the adapters, so iterating it must yield
                one control input per adapter (e.g. a list of `(batch, channel, height, width)` tensors).
                NOTE(review): earlier documentation described the inputs as concatenated along dimension 1; the
                loop below iterates over the first axis / list elements instead - confirm against callers.
            adapter_weights (`List[float]`, *optional*, defaults to None):
                One weight per adapter; each adapter's features are multiplied by its weight before being summed.
                Defaults to a uniform `1 / num_adapter` for every adapter.

        Returns:
            `List[torch.Tensor]`: One tensor per feature level, each the weighted sum over all adapters.
        """
        if adapter_weights is None:
            adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
        else:
            adapter_weights = torch.tensor(adapter_weights)

        accume_state = None
        # zip stops at the shortest of (xs, adapter_weights, adapters).
        for x, w, adapter in zip(xs, adapter_weights, self.adapters):
            features = adapter(x)
            if accume_state is None:
                accume_state = [w * feature for feature in features]
            else:
                for i, feature in enumerate(features):
                    accume_state[i] += w * feature
        return accume_state

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        save_function: Callable = None,
        safe_serialization: bool = True,
        variant: Optional[str] = None,
    ):
        """
        Save every wrapped adapter to its own directory, so that the set can be re-loaded using the
        `[`~models.adapter.MultiAdapter.from_pretrained`]` class method.

        The first adapter is written to `save_directory`, the following ones to `save_directory_1`,
        `save_directory_2`, ... which is exactly the layout `from_pretrained` looks for.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful when in distributed training like
                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
                the main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                need to replace `torch.save` by another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
            variant (`str`, *optional*):
                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
        """
        # Normalise to `str` so that `os.PathLike` inputs work with the suffix formatting below.
        base_path = os.fspath(save_directory)
        for idx, adapter in enumerate(self.adapters):
            # Always derive the target from the base path: appending to the previous target would yield
            # `<dir>_1_2` for the third adapter, which `from_pretrained` never probes.
            model_path_to_save = base_path if idx == 0 else f"{base_path}_{idx}"
            adapter.save_pretrained(
                model_path_to_save,
                is_main_process=is_main_process,
                save_function=save_function,
                safe_serialization=safe_serialization,
                variant=variant,
            )

    @classmethod
    def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models.

        Adapters are loaded from `pretrained_model_path`, `pretrained_model_path_1`, `pretrained_model_path_2`, ...
        until the next directory in that sequence does not exist. All keyword arguments are forwarded unchanged to
        `T2IAdapter.from_pretrained` for every adapter.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
        the model, you should first set it back in training mode with `model.train()`.

        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_path (`os.PathLike`):
                A path to a *directory* containing model weights saved using
                [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
                will be automatically derived from the model's weights.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be refined to each
                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
                same device.

                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
                GPU and the available CPU RAM if unset.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                setting this argument to `True` will raise an error.
            variant (`str`, *optional*):
                If specified load weights from `variant` filename, *e.g.* `pytorch_model.<variant>.bin`. `variant` is
                ignored when using `from_flax`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.

        Raises:
            ValueError: If `pretrained_model_path` is not an existing directory (no adapter could be loaded), or - via
                the constructor - if only a single adapter directory was found.
        """
        # Normalise to `str` so that `os.PathLike` inputs work with the suffix formatting below.
        base_path = os.fspath(pretrained_model_path)

        idx = 0
        adapters = []

        # load adapter and append to list until no adapter directory exists anymore
        # first adapter has to be saved under `./mydirectory/adapter` to be compliant with `DiffusionPipeline.from_pretrained`
        # second, third, ... adapters have to be saved under `./mydirectory/adapter_1`, `./mydirectory/adapter_2`, ...
        model_path_to_load = base_path
        while os.path.isdir(model_path_to_load):
            adapter = T2IAdapter.from_pretrained(model_path_to_load, **kwargs)
            adapters.append(adapter)

            idx += 1
            model_path_to_load = f"{base_path}_{idx}"

        logger.info(f"{len(adapters)} adapters loaded from {pretrained_model_path}.")

        if len(adapters) == 0:
            # The first adapter lives at the base path itself (there is no `_0` suffix).
            raise ValueError(
                f"No T2IAdapters found under {os.path.dirname(base_path)}. Expected at least {base_path}."
            )

        return cls(adapters)
# full adapter


class FullAdapter(nn.Module):
    r"""
    Backbone used by [`T2IAdapter`] for `adapter_type="full_adapter"`.

    The input is pixel-unshuffled by `downscale_factor`, projected to `channels[0]` by a 3x3 convolution and then
    passed through one `AdapterBlock` per entry of `channels`; every block after the first downsamples by 2.

    Parameters:
        in_channels (`int`, defaults to 3): Number of channels of the control image.
        channels (`List[int]`): Output width of each `AdapterBlock`.
        num_res_blocks (`int`, defaults to 2): Number of residual blocks inside each `AdapterBlock`.
        downscale_factor (`int`, defaults to 8): Factor of the initial pixel unshuffle.
    """

    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280, 1280],
        num_res_blocks: int = 2,
        downscale_factor: int = 8,
    ):
        super().__init__()

        # Pixel unshuffle moves each (factor x factor) spatial patch into the channel axis.
        unshuffled_channels = in_channels * downscale_factor**2

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)
        self.conv_in = nn.Conv2d(unshuffled_channels, channels[0], kernel_size=3, padding=1)

        blocks = [AdapterBlock(channels[0], channels[0], num_res_blocks)]
        for i in range(1, len(channels)):
            blocks.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True))
        self.body = nn.ModuleList(blocks)

        # One factor of 2 for every block after the first.
        self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        r"""
        Unshuffle and project `x`, then run it through the blocks in order and return the output of every block (one
        feature tensor per entry of `channels`).
        """
        x = self.conv_in(self.unshuffle(x))

        features = []
        for block in self.body:
            x = block(x)
            features.append(x)
        return features


class FullAdapterXL(nn.Module):
    r"""
    Backbone used by [`T2IAdapter`] for `adapter_type="full_adapter_xl"`.

    Same stem as `FullAdapter`, but only the third block (index 2) downsamples. With the default configuration and a
    1024x1024 input the features have shapes `[320, 64, 64]`, `[640, 64, 64]`, `[1280, 32, 32]`, `[1280, 32, 32]`.

    Parameters:
        in_channels (`int`, defaults to 3): Number of channels of the control image.
        channels (`List[int]`): Output width of each `AdapterBlock`.
        num_res_blocks (`int`, defaults to 2): Number of residual blocks inside each `AdapterBlock`.
        downscale_factor (`int`, defaults to 16): Factor of the initial pixel unshuffle.
    """

    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280, 1280],
        num_res_blocks: int = 2,
        downscale_factor: int = 16,
    ):
        super().__init__()

        unshuffled_channels = in_channels * downscale_factor**2

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)
        self.conv_in = nn.Conv2d(unshuffled_channels, channels[0], kernel_size=3, padding=1)

        blocks = []
        for i, out_ch in enumerate(channels):
            # The first block consumes the stem output (`channels[0]`); every later block consumes the previous
            # block's output. Using `channels[i - 1]` (rather than `channels[i]`) for blocks past index 2 keeps the
            # default layout unchanged - no 1x1 projection is created when both widths match - while making
            # configurations with differing trailing widths runnable.
            in_ch = channels[0] if i == 0 else channels[i - 1]
            blocks.append(AdapterBlock(in_ch, out_ch, num_res_blocks, down=(i == 2)))
        self.body = nn.ModuleList(blocks)

        # XL has only one downsampling AdapterBlock (this assumes `channels` has at least three entries).
        self.total_downscale_factor = downscale_factor * 2

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        r"""
        Unshuffle and project `x`, then run it through the blocks in order and return the output of every block (one
        feature tensor per entry of `channels`).
        """
        x = self.conv_in(self.unshuffle(x))

        features = []
        for block in self.body:
            x = block(x)
            features.append(x)
        return features


class AdapterBlock(nn.Module):
    r"""
    An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
    `FullAdapterXL` models.

    Parameters:
        in_channels (`int`):
            Number of channels of AdapterBlock's input.
        out_channels (`int`):
            Number of channels of AdapterBlock's output.
        num_res_blocks (`int`):
            Number of ResNet blocks in the AdapterBlock.
        down (`bool`, *optional*, defaults to `False`):
            Whether to perform downsampling on AdapterBlock's input.
    """

    def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
        super().__init__()

        # 2x average pooling; `ceil_mode=True` rounds odd spatial sizes up instead of dropping the last row/column.
        self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True) if down else None

        # A 1x1 projection exists only when the channel count changes.
        self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None

        self.resnets = nn.Sequential(*[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Apply the optional downsampling, the optional 1x1 channel projection and then the residual blocks to `x`.
        """
        if self.downsample is not None:
            x = self.downsample(x)
        if self.in_conv is not None:
            x = self.in_conv(x)
        return self.resnets(x)


class AdapterResnetBlock(nn.Module):
    r"""
    An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.

    Parameters:
        channels (`int`):
            Number of channels of AdapterResnetBlock's input and output.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Compute `x + block2(relu(block1(x)))`: a 3x3 convolution, ReLU and a 1x1 convolution with a skip connection.
        """
        residual = self.block2(self.act(self.block1(x)))
        return residual + x


# light adapter


class LightAdapter(nn.Module):
    r"""
    Backbone used by [`T2IAdapter`] for `adapter_type="light_adapter"`.

    The input is pixel-unshuffled by `downscale_factor` and fed through `len(channels) + 1` `LightAdapterBlock`s:
    one block without downsampling, one downsampling block per channel transition, and a final downsampling block
    that keeps `channels[-1]`.

    Parameters:
        in_channels (`int`, defaults to 3): Number of channels of the control image.
        channels (`List[int]`): Output widths of the blocks.
        num_res_blocks (`int`, defaults to 4): Number of residual blocks inside each `LightAdapterBlock`.
        downscale_factor (`int`, defaults to 8): Factor of the initial pixel unshuffle.
    """

    def __init__(
        self,
        in_channels: int = 3,
        channels: List[int] = [320, 640, 1280],
        num_res_blocks: int = 4,
        downscale_factor: int = 8,
    ):
        super().__init__()

        unshuffled_channels = in_channels * downscale_factor**2

        self.unshuffle = nn.PixelUnshuffle(downscale_factor)

        blocks = [LightAdapterBlock(unshuffled_channels, channels[0], num_res_blocks)]
        for i in range(len(channels) - 1):
            blocks.append(LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True))
        blocks.append(LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True))
        self.body = nn.ModuleList(blocks)

        # `len(channels)` of the blocks above downsample by 2.
        self.total_downscale_factor = downscale_factor * (2 ** len(channels))

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        r"""
        Unshuffle `x`, run it through the blocks in order and return the output of every block.
        """
        x = self.unshuffle(x)

        features = []
        for block in self.body:
            x = block(x)
            features.append(x)
        return features


class LightAdapterBlock(nn.Module):
    r"""
    A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
    `LightAdapter` model. The residual blocks operate on a bottleneck of `out_channels // 4` channels.

    Parameters:
        in_channels (`int`):
            Number of channels of LightAdapterBlock's input.
        out_channels (`int`):
            Number of channels of LightAdapterBlock's output.
        num_res_blocks (`int`):
            Number of LightAdapterResnetBlocks in the LightAdapterBlock.
        down (`bool`, *optional*, defaults to `False`):
            Whether to perform downsampling on LightAdapterBlock's input.
    """

    def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
        super().__init__()
        mid_channels = out_channels // 4

        self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True) if down else None

        self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
        self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Optionally downsample `x`, then apply the 1x1 input projection, the residual blocks and the 1x1 output
        projection.
        """
        if self.downsample is not None:
            x = self.downsample(x)
        return self.out_conv(self.resnets(self.in_conv(x)))


class LightAdapterResnetBlock(nn.Module):
    """
    A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
    architecture than `AdapterResnetBlock` (both convolutions are 3x3).

    Parameters:
        channels (`int`):
            Number of channels of LightAdapterResnetBlock's input and output.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Compute `x + block2(relu(block1(x)))`: two 3x3 convolutions with a ReLU in between and a skip connection.
        """
        residual = self.block2(self.act(self.block1(x)))
        return residual + x
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn.functional as F from torch import nn from ..utils import deprecate, logging from ..utils.torch_utils import maybe_allow_in_graph from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU from .attention_processor import Attention, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm logger = logging.get_logger(__name__) def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): # "feed_forward_chunk_size" can be used to save memory if hidden_states.shape[chunk_dim] % chunk_size != 0: raise ValueError( f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." ) num_chunks = hidden_states.shape[chunk_dim] // chunk_size ff_output = torch.cat( [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], dim=chunk_dim, ) return ff_output @maybe_allow_in_graph class GatedSelfAttentionDense(nn.Module): r""" A gated self-attention dense layer that combines visual features and object features. Parameters: query_dim (`int`): The number of channels in the query. context_dim (`int`): The number of channels in the context. n_heads (`int`): The number of heads to use for attention. d_head (`int`): The number of channels in each head. 
@maybe_allow_in_graph
class JointTransformerBlock(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    The block keeps two streams - `hidden_states` and `encoder_hidden_states` (the context) - that go through one
    joint attention call and then through separate, timestep-modulated feed-forward networks.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        context_pre_only (`bool`): If `True`, the context stream is only normalized and fed to the attention; its
            post-attention layers (`norm2_context`, `ff_context`) are not created and `forward` returns `None` for
            the context.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
        super().__init__()

        self.context_pre_only = context_pre_only

        self.norm1 = AdaLayerNormZero(dim)

        # A context stream that stops after attention needs no gate / MLP modulation, so the continuous AdaLN
        # variant is used for it; otherwise the context gets the same zero-init AdaLN as the main stream.
        if context_pre_only:
            self.norm1_context = AdaLayerNormContinuous(
                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
            )
        else:
            self.norm1_context = AdaLayerNormZero(dim)

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )

        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=context_pre_only,
            bias=True,
            processor=JointAttnProcessor2_0(),
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        if context_pre_only:
            self.norm2_context = None
            self.ff_context = None
        else:
            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def _feed_forward_maybe_chunked(self, ff: nn.Module, states: torch.Tensor) -> torch.Tensor:
        # "feed_forward_chunk_size" can be used to save memory: run `ff` chunk-wise when a chunk size is set.
        if self._chunk_size is not None:
            return _chunked_feed_forward(ff, states, self._chunk_dim, self._chunk_size)
        return ff(states)

    def forward(
        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
    ):
        r"""
        Run joint attention over both streams followed by the per-stream feed-forward networks.

        Args:
            hidden_states: Main-stream input.
            encoder_hidden_states: Context-stream input.
            temb: Conditioning embedding handed to the adaptive layer norms.

        Returns:
            A tuple `(encoder_hidden_states, hidden_states)`; the first entry is `None` when `context_pre_only` is
            set.
        """
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
                encoder_hidden_states, emb=temb
            )

        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
        )

        # Process attention outputs for the `hidden_states`.
        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output

        modulated = self.norm2(hidden_states)
        modulated = modulated * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self._feed_forward_maybe_chunked(self.ff, modulated)
        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output

        # Process attention outputs for the `encoder_hidden_states`.
        if self.context_pre_only:
            return None, hidden_states

        encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * context_attn_output

        context_modulated = self.norm2_context(encoder_hidden_states)
        context_modulated = context_modulated * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
        context_ff_output = self._feed_forward_maybe_chunked(self.ff_context, context_modulated)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

        return encoder_hidden_states, hidden_states
Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. attention_bias (: obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. only_cross_attention (`bool`, *optional*): Whether to use only cross-attention layers. In this case two cross attention layers are used. double_self_attention (`bool`, *optional*): Whether to use two self-attention layers. In this case no cross attention layers are used. upcast_attention (`bool`, *optional*): Whether to upcast the attention computation to float32. This is useful for mixed precision training. norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Whether to use learnable elementwise affine parameters for normalization. norm_type (`str`, *optional*, defaults to `"layer_norm"`): The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. final_dropout (`bool` *optional*, defaults to False): Whether to apply a final dropout after the last feed-forward layer. attention_type (`str`, *optional*, defaults to `"default"`): The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. positional_embeddings (`str`, *optional*, defaults to `None`): The type of positional embeddings to apply to. num_positional_embeddings (`int`, *optional*, defaults to `None`): The maximum number of positional embeddings to apply. 
""" def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' norm_eps: float = 1e-5, final_dropout: bool = False, attention_type: str = "default", positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, ada_norm_bias: Optional[int] = None, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() self.dim = dim self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim self.dropout = dropout self.cross_attention_dim = cross_attention_dim self.activation_fn = activation_fn self.attention_bias = attention_bias self.double_self_attention = double_self_attention self.norm_elementwise_affine = norm_elementwise_affine self.positional_embeddings = positional_embeddings self.num_positional_embeddings = num_positional_embeddings self.only_cross_attention = only_cross_attention # We keep these boolean flags for backward-compatibility. 
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" self.use_ada_layer_norm_single = norm_type == "ada_norm_single" self.use_layer_norm = norm_type == "layer_norm" self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: raise ValueError( f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." ) self.norm_type = norm_type self.num_embeds_ada_norm = num_embeds_ada_norm if positional_embeddings and (num_positional_embeddings is None): raise ValueError( "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." ) if positional_embeddings == "sinusoidal": self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) else: self.pos_embed = None # Define 3 blocks. Each block has its own normalization layer. # 1. Self-Attn if norm_type == "ada_norm": self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) elif norm_type == "ada_norm_zero": self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) elif norm_type == "ada_norm_continuous": self.norm1 = AdaLayerNormContinuous( dim, ada_norm_continous_conditioning_embedding_dim, norm_elementwise_affine, norm_eps, ada_norm_bias, "rms_norm", ) else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, out_bias=attention_out_bias, ) # 2. 
Cross-Attn if cross_attention_dim is not None or double_self_attention: # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during # the second cross attention block. if norm_type == "ada_norm": self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) elif norm_type == "ada_norm_continuous": self.norm2 = AdaLayerNormContinuous( dim, ada_norm_continous_conditioning_embedding_dim, norm_elementwise_affine, norm_eps, ada_norm_bias, "rms_norm", ) else: self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim if not double_self_attention else None, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, out_bias=attention_out_bias, ) # is self-attn if encoder_hidden_states is none else: if norm_type == "ada_norm_single": # For Latte self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) else: self.norm2 = None self.attn2 = None # 3. Feed-forward if norm_type == "ada_norm_continuous": self.norm3 = AdaLayerNormContinuous( dim, ada_norm_continous_conditioning_embedding_dim, norm_elementwise_affine, norm_eps, ada_norm_bias, "layer_norm", ) elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) elif norm_type == "layer_norm_i2vgen": self.norm3 = None self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) # 4. Fuser if attention_type == "gated" or attention_type == "gated-text-image": self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) # 5. Scale-shift for PixArt-Alpha. 
if norm_type == "ada_norm_single": self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) # let chunk size default to None self._chunk_size = None self._chunk_dim = 0 def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): # Sets chunk feed-forward self._chunk_size = chunk_size self._chunk_dim = dim def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, timestep: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") # Notice that normalization is always applied before the real computation in the following blocks. # 0. 
Self-Attention batch_size = hidden_states.shape[0] if self.norm_type == "ada_norm": norm_hidden_states = self.norm1(hidden_states, timestep) elif self.norm_type == "ada_norm_zero": norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype ) elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: norm_hidden_states = self.norm1(hidden_states) elif self.norm_type == "ada_norm_continuous": norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) elif self.norm_type == "ada_norm_single": shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) ).chunk(6, dim=1) norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa else: raise ValueError("Incorrect norm used") if self.pos_embed is not None: norm_hidden_states = self.pos_embed(norm_hidden_states) # 1. Prepare GLIGEN inputs cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} gligen_kwargs = cross_attention_kwargs.pop("gligen", None) attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, ) if self.norm_type == "ada_norm_zero": attn_output = gate_msa.unsqueeze(1) * attn_output elif self.norm_type == "ada_norm_single": attn_output = gate_msa * attn_output hidden_states = attn_output + hidden_states if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) # 1.2 GLIGEN Control if gligen_kwargs is not None: hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) # 3. 
Cross-Attention if self.attn2 is not None: if self.norm_type == "ada_norm": norm_hidden_states = self.norm2(hidden_states, timestep) elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: norm_hidden_states = self.norm2(hidden_states) elif self.norm_type == "ada_norm_single": # For PixArt norm2 isn't applied here: # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 norm_hidden_states = hidden_states elif self.norm_type == "ada_norm_continuous": norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) else: raise ValueError("Incorrect norm") if self.pos_embed is not None and self.norm_type != "ada_norm_single": norm_hidden_states = self.pos_embed(norm_hidden_states) attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) hidden_states = attn_output + hidden_states # 4. 
Feed-forward # i2vgen doesn't have this norm 🤷‍♂️ if self.norm_type == "ada_norm_continuous": norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) elif not self.norm_type == "ada_norm_single": norm_hidden_states = self.norm3(hidden_states) if self.norm_type == "ada_norm_zero": norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] if self.norm_type == "ada_norm_single": norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) else: ff_output = self.ff(norm_hidden_states) if self.norm_type == "ada_norm_zero": ff_output = gate_mlp.unsqueeze(1) * ff_output elif self.norm_type == "ada_norm_single": ff_output = gate_mlp * ff_output hidden_states = ff_output + hidden_states if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) return hidden_states class LuminaFeedForward(nn.Module): r""" A feed-forward layer. Parameters: hidden_size (`int`): The dimensionality of the hidden layers in the model. This parameter determines the width of the model's hidden representations. intermediate_size (`int`): The intermediate dimension of the feedforward layer. multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple of this value. ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden dimension. Defaults to None. 
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block for video like data.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        time_mix_inner_dim (`int`): The number of channels for temporal attention.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
    """

    def __init__(
        self,
        dim: int,
        time_mix_inner_dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        # Residual connections around `ff_in` and `ff` are only used when the channel count is preserved.
        self.is_res = dim == time_mix_inner_dim

        self.norm_in = nn.LayerNorm(dim)

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.ff_in = FeedForward(
            dim,
            dim_out=time_mix_inner_dim,
            activation_fn="geglu",
        )

        self.norm1 = nn.LayerNorm(time_mix_inner_dim)
        self.attn1 = Attention(
            query_dim=time_mix_inner_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            cross_attention_dim=None,
        )

        # 2. Cross-Attn (only built when a cross-attention dimension is given)
        if cross_attention_dim is not None:
            self.norm2 = nn.LayerNorm(time_mix_inner_dim)
            self.attn2 = Attention(
                query_dim=time_mix_inner_dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(time_mix_inner_dim)
        self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = None

    def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
        """Enable chunked feed-forward with the given chunk size; extra keyword arguments are ignored."""
        self._chunk_size = chunk_size
        # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
        self._chunk_dim = 1

    def forward(
        self,
        hidden_states: torch.Tensor,
        num_frames: int,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply attention along the frame axis.

        Args:
            hidden_states: Tensor of shape `(batch_size * num_frames, seq_length, channels)`.
            num_frames: Number of frames folded into the first dimension of `hidden_states`.
            encoder_hidden_states: Conditioning states for the cross-attention layer, when it exists.

        Returns:
            Tensor of shape `(batch_size * num_frames, seq_length, out_channels)`, where `out_channels` is the
            channel count produced by the block's layers.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_frames, seq_length, channels = hidden_states.shape
        batch_size = batch_frames // num_frames

        # (batch * frames, seq, C) -> (batch * seq, frames, C) so that attention runs over frames.
        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)

        residual = hidden_states
        hidden_states = self.norm_in(hidden_states)

        if self._chunk_size is not None:
            hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
        else:
            hidden_states = self.ff_in(hidden_states)

        if self.is_res:
            hidden_states = hidden_states + residual

        norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
        hidden_states = attn_output + hidden_states

        # 3. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states)
            attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.is_res:
            hidden_states = ff_output + hidden_states
        else:
            hidden_states = ff_output

        # `ff_in` may have changed the channel count (dim -> time_mix_inner_dim), so restore the layout with
        # the tensor's current channel count rather than the input's.
        out_channels = hidden_states.shape[-1]
        hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, out_channels)
        hidden_states = hidden_states.permute(0, 2, 1, 3)
        hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, out_channels)

        return hidden_states
@maybe_allow_in_graph
class FreeNoiseTransformerBlock(nn.Module):
    r"""
    A FreeNoise Transformer block.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        cross_attention_dim (`int`, *optional*):
            The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (`bool`, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, defaults to `False`):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, defaults to `False`):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, defaults to `False`):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        norm_eps (`float`, defaults to 1e-5):
            Epsilon forwarded to the `nn.LayerNorm` layers of this block.
        final_dropout (`bool` defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        positional_embeddings (`str`, *optional*):
            The type of positional embeddings to apply to.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
        ff_inner_dim (`int`, *optional*):
            Hidden dimension of feed-forward MLP.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in feed-forward MLP.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in attention output project layer.
        context_length (`int`, defaults to `16`):
            The maximum number of frames that the FreeNoise block processes at once.
        context_stride (`int`, defaults to `4`):
            The number of frames to be skipped before starting to process a new batch of `context_length` frames.
        weighting_scheme (`str`, defaults to `"pyramid"`):
            The weighting scheme to use for weighting averaging of processed latent frames. As described in the
            Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
            used.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
        context_length: int = 16,
        context_stride: int = 4,
        weighting_scheme: str = "pyramid",
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.dropout = dropout
        self.cross_attention_dim = cross_attention_dim
        self.activation_fn = activation_fn
        self.attention_bias = attention_bias
        self.double_self_attention = double_self_attention
        self.norm_elementwise_affine = norm_elementwise_affine
        self.positional_embeddings = positional_embeddings
        self.num_positional_embeddings = num_positional_embeddings
        self.only_cross_attention = only_cross_attention

        self.set_free_noise_properties(context_length, context_stride, weighting_scheme)

        # We keep these boolean flags for backward-compatibility.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"
        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        self.norm_type = norm_type
        self.num_embeds_ada_norm = num_embeds_ada_norm

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                out_bias=attention_out_bias,
            )  # is self-attn if encoder_hidden_states is none
        else:
            # `forward` tests `self.attn2 is not None`; without these assignments the attribute lookup itself
            # raises `AttributeError` on a block built without cross-attention.
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
        """Return `(start, end)` windows of `context_length` frames, spaced `context_stride` frames apart."""
        frame_indices = []
        for i in range(0, num_frames - self.context_length + 1, self.context_stride):
            window_start = i
            window_end = min(num_frames, i + self.context_length)
            frame_indices.append((window_start, window_end))
        return frame_indices

    def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
        """
        Return one weight per frame of a window.

        Raises:
            ValueError: If `weighting_scheme` is not `"pyramid"`.
        """
        if weighting_scheme == "pyramid":
            if num_frames % 2 == 0:
                # num_frames = 4 => [1, 2, 2, 1]
                weights = list(range(1, num_frames // 2 + 1))
                weights = weights + weights[::-1]
            else:
                # num_frames = 5 => [1, 2, 3, 2, 1]
                weights = list(range(1, num_frames // 2 + 1))
                weights = weights + [num_frames // 2 + 1] + weights[::-1]
        else:
            raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")

        return weights

    def set_free_noise_properties(
        self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
    ) -> None:
        """Store the window length, window stride and weighting scheme used by `forward`."""
        self.context_length = context_length
        self.context_stride = context_stride
        self.weighting_scheme = weighting_scheme

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
        """Store the chunk size and chunk dimension used for the chunked feed-forward path in `forward`."""
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run attention over sliding windows of frames, blend the windows with a weighted average, then apply the
        feed-forward on the full sequence.

        Args:
            hidden_states: Input tensor whose dimension 1 indexes frames (`[B x H x W, F, C]`).
            attention_mask: Mask forwarded to the first attention layer.
            encoder_hidden_states: Conditioning states forwarded to the second attention layer (and to the first
                one when `only_cross_attention` is set).
            encoder_attention_mask: Mask forwarded to the second attention layer.
            cross_attention_kwargs: Extra keyword arguments forwarded to both attention layers. The dict is
                copied, not mutated.

        Returns:
            Tensor with the same shape and dtype as `hidden_states`.

        Raises:
            ValueError: If the input has fewer frames than `context_length`.
        """
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}

        # hidden_states: [B x H x W, F, C]
        device = hidden_states.device
        dtype = hidden_states.dtype

        num_frames = hidden_states.size(1)
        # Validate before computing windows: with too few frames there is no window at all and indexing
        # `frame_indices[-1]` below would fail with an unrelated `IndexError`.
        if num_frames < self.context_length:
            raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")

        frame_indices = self._get_frame_indices(num_frames)
        frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
        frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
        is_last_frame_batch_complete = frame_indices[-1][1] == num_frames

        # Handle the case where the strided windows do not reach the last frame by appending one extra window
        # aligned to the end of the sequence. For example, num_frames=25, context_length=16, context_stride=4
        # gives the ranges: [(0, 16), (4, 20), (8, 24), (9, 25)]
        if not is_last_frame_batch_complete:
            last_frame_batch_length = num_frames - frame_indices[-1][1]
            frame_indices.append((num_frames - self.context_length, num_frames))

        num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
        accumulated_values = torch.zeros_like(hidden_states)

        for i, (frame_start, frame_end) in enumerate(frame_indices):
            # Every window spans exactly `context_length` frames, so the per-window weights line up with
            # `frame_weights`.
            weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
            weights *= frame_weights

            hidden_states_chunk = hidden_states[:, frame_start:frame_end]

            # Notice that normalization is always applied before the real computation in the following blocks.
            # 1. Self-Attention
            norm_hidden_states = self.norm1(hidden_states_chunk)

            if self.pos_embed is not None:
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )

            hidden_states_chunk = attn_output + hidden_states_chunk
            if hidden_states_chunk.ndim == 4:
                hidden_states_chunk = hidden_states_chunk.squeeze(1)

            # 2. Cross-Attention
            if self.attn2 is not None:
                norm_hidden_states = self.norm2(hidden_states_chunk)

                if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                    norm_hidden_states = self.pos_embed(norm_hidden_states)

                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states_chunk = attn_output + hidden_states_chunk

            if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
                # The appended window only contributes the trailing frames that no earlier window covered.
                accumulated_values[:, -last_frame_batch_length:] += (
                    hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
                )
                # The normalizer must receive the same per-frame weights as the values; a single-index lookup
                # here would broadcast one weight over all trailing frames and mis-scale them.
                num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
            else:
                accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
                num_times_accumulated[:, frame_start:frame_end] += weights

        # Weighted average over all windows that touched a frame.
        hidden_states = torch.where(
            num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
        ).to(dtype)

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
def __init__(
    self,
    dim: int,
    dim_out: Optional[int] = None,
    mult: int = 4,
    dropout: float = 0.0,
    activation_fn: str = "geglu",
    final_dropout: bool = False,
    inner_dim=None,
    bias: bool = True,
):
    """
    Build the feed-forward stack: input projection with activation, dropout, output projection and an optional
    final dropout.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels; falls back to `dim` when `None`.
        mult: Multiplier applied to `dim` to obtain the hidden size when `inner_dim` is `None`.
        dropout: Dropout probability.
        activation_fn: One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"`, `"swiglu"`.
        final_dropout: Whether to append a dropout layer after the output projection.
        inner_dim: Explicit hidden size; overrides `dim * mult` when given.
        bias: Whether the linear layers use a bias.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """
    super().__init__()
    if inner_dim is None:
        inner_dim = int(dim * mult)
    dim_out = dim_out if dim_out is not None else dim

    # Single dispatch chain: exactly one branch runs, and an unsupported name fails here with a clear message
    # instead of leaving `act_fn` unbound further down.
    if activation_fn == "gelu":
        act_fn = GELU(dim, inner_dim, bias=bias)
    elif activation_fn == "gelu-approximate":
        act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
    elif activation_fn == "geglu":
        act_fn = GEGLU(dim, inner_dim, bias=bias)
    elif activation_fn == "geglu-approximate":
        act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
    elif activation_fn == "swiglu":
        act_fn = SwiGLU(dim, inner_dim, bias=bias)
    else:
        raise ValueError(f"Unsupported activation_fn: {activation_fn}")

    self.net = nn.ModuleList([])
    # project in
    self.net.append(act_fn)
    # project dropout
    self.net.append(nn.Dropout(dropout))
    # project out
    self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
    # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
    if final_dropout:
        self.net.append(nn.Dropout(dropout))
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) for module in self.net: hidden_states = module(hidden_states) return hidden_states ================================================ FILE: src/diffusers/models/attention_flax.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools import math import flax.linen as nn import jax import jax.numpy as jnp def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): """Multi-head dot product attention with a limited number of queries.""" num_kv, num_heads, k_features = key.shape[-3:] v_features = value.shape[-1] key_chunk_size = min(key_chunk_size, num_kv) query = query / jnp.sqrt(k_features) @functools.partial(jax.checkpoint, prevent_cse=False) def summarize_chunk(query, key, value): attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision) max_score = jnp.max(attn_weights, axis=-1, keepdims=True) max_score = jax.lax.stop_gradient(max_score) exp_weights = jnp.exp(attn_weights - max_score) exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision) max_score = jnp.einsum("...qhk->...qh", max_score) return (exp_values, exp_weights.sum(axis=-1), max_score) def chunk_scanner(chunk_idx): # julienne key array key_chunk = 
jax.lax.dynamic_slice( operand=key, start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d] slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d] ) # julienne value array value_chunk = jax.lax.dynamic_slice( operand=value, start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d] slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d] ) return summarize_chunk(query, key_chunk, value_chunk) chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size)) global_max = jnp.max(chunk_max, axis=0, keepdims=True) max_diffs = jnp.exp(chunk_max - global_max) chunk_values *= jnp.expand_dims(max_diffs, axis=-1) chunk_weights *= max_diffs all_values = chunk_values.sum(axis=0) all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0) return all_values / all_weights def jax_memory_efficient_attention( query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096 ): r""" Flax Memory-efficient multi-head dot product attention. 
def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
    https://github.com/AminRezaei0x443/memory-efficient-attention

    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for computation
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size to divide query array value must divide query_length equally without remainder
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size to divide key and value array value must divide key_value_length equally without remainder

    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    num_q, num_heads, q_features = query.shape[-3:]
    batch_dims = list(query.shape[:-3])
    rows_per_chunk = min(query_chunk_size, num_q)
    num_chunks = math.ceil(num_q / query_chunk_size)

    def chunk_scanner(chunk_idx, _):
        # cut the next block of queries out of [..., q, h, d]
        query_chunk = jax.lax.dynamic_slice(
            operand=query,
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],
            slice_sizes=batch_dims + [rows_per_chunk, num_heads, q_features],
        )
        attended = _query_chunk_attention(
            query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
        )
        # the first element is the scan carry: the start offset of the following chunk
        return (chunk_idx + query_chunk_size, attended)

    _, per_chunk = jax.lax.scan(f=chunk_scanner, init=0, xs=None, length=num_chunks)

    # stitch the per-chunk outputs back together along the query axis
    return jnp.concatenate(per_chunk, axis=-3)
class FlaxAttention(nn.Module):
    r"""
    A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762

    Parameters:
        query_dim (:obj:`int`):
            Input hidden states dimension
        heads (:obj:`int`, *optional*, defaults to 8):
            Number of heads
        dim_head (:obj:`int`, *optional*, defaults to 64):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`

    """

    query_dim: int
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the query/key/value/output projections and the output dropout layer."""
        inner_dim = self.dim_head * self.heads
        # scaling applied to the attention scores in the non-memory-efficient path
        self.scale = self.dim_head**-0.5

        # Weights were exported with old names {to_q, to_k, to_v, to_out}
        self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
        self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
        self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def reshape_heads_to_batch_dim(self, tensor):
        """Reshape `(batch, seq_len, dim)` into `(batch * heads, seq_len, dim // heads)`."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        """Inverse of `reshape_heads_to_batch_dim`: `(batch * heads, seq_len, dim)` -> `(batch, seq_len, dim * heads)`."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def __call__(self, hidden_states, context=None, deterministic=True):
        """Attend `hidden_states` (queries) to `context` (keys/values).

        When `context` is `None` the queries attend to themselves. `deterministic` is
        forwarded to the output dropout layer.
        """
        context = hidden_states if context is None else context

        query_proj = self.query(hidden_states)
        key_proj = self.key(context)
        value_proj = self.value(context)

        if self.split_head_dim:
            # keep heads on their own axis: (batch, seq_len, heads, dim_head)
            b = hidden_states.shape[0]
            query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
            key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
            value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
        else:
            # fold heads into the batch axis: (batch * heads, seq_len, dim_head)
            query_states = self.reshape_heads_to_batch_dim(query_proj)
            key_states = self.reshape_heads_to_batch_dim(key_proj)
            value_states = self.reshape_heads_to_batch_dim(value_proj)

        if self.use_memory_efficient_attention:
            # NOTE(review): these three-axis transposes presumably require the 3-D layout
            # produced when `split_head_dim` is False — confirm the two flags are never combined.
            query_states = query_states.transpose(1, 0, 2)
            key_states = key_states.transpose(1, 0, 2)
            value_states = value_states.transpose(1, 0, 2)

            # Pick a query chunk size for this layer: the query length divided by the
            # largest of 64, 16 or 4 that divides it evenly, else the whole query length.
            # (original note: the chunk size is equal to the query_length dimension of the
            # deepest layer of the unet)

            flatten_latent_dim = query_states.shape[-3]
            if flatten_latent_dim % 64 == 0:
                query_chunk_size = int(flatten_latent_dim / 64)
            elif flatten_latent_dim % 16 == 0:
                query_chunk_size = int(flatten_latent_dim / 16)
            elif flatten_latent_dim % 4 == 0:
                query_chunk_size = int(flatten_latent_dim / 4)
            else:
                query_chunk_size = int(flatten_latent_dim)

            hidden_states = jax_memory_efficient_attention(
                query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
            )

            # NOTE(review): unlike the plain non-split branch below, this path does not call
            # `reshape_batch_dim_to_heads` before `proj_attn` — confirm the expected layout.
            hidden_states = hidden_states.transpose(1, 0, 2)
        else:
            # compute attentions
            if self.split_head_dim:
                attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
            else:
                attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)

            attention_scores = attention_scores * self.scale
            # softmax over the key axis (last axis in both layouts)
            attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)

            # attend to values
            if self.split_head_dim:
                hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
                b = hidden_states.shape[0]
                # merge heads back into the feature axis
                hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
            else:
                hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
                hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

        hidden_states = self.proj_attn(hidden_states)
        return self.dropout_layer(hidden_states, deterministic=deterministic)
class FlaxBasicTransformerBlock(nn.Module):
    r"""
    A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
    https://arxiv.org/abs/1706.03762

    The block chains three pre-normalised residual sub-layers (attention, cross attention,
    feed-forward) and applies dropout to the final result.

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        only_cross_attention (`bool`, defaults to `False`):
            Whether to only apply cross attention.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """

    dim: int
    n_heads: int
    d_head: int
    dropout: float = 0.0
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False

    def setup(self):
        # first attention layer: self attention, or cross attention when `only_cross_attention` is set
        self.attn1 = FlaxAttention(
            query_dim=self.dim,
            heads=self.n_heads,
            dim_head=self.d_head,
            dropout=self.dropout,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            split_head_dim=self.split_head_dim,
            dtype=self.dtype,
        )
        # second attention layer: always attends to the context
        self.attn2 = FlaxAttention(
            query_dim=self.dim,
            heads=self.n_heads,
            dim_head=self.d_head,
            dropout=self.dropout,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            split_head_dim=self.split_head_dim,
            dtype=self.dtype,
        )
        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        # 1. first attention (+ residual); it only sees the context in cross-attention-only mode
        first_context = context if self.only_cross_attention else None
        attended = self.attn1(self.norm1(hidden_states), first_context, deterministic=deterministic)
        hidden_states = attended + hidden_states

        # 2. cross attention (+ residual)
        cross_attended = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
        hidden_states = cross_attended + hidden_states

        # 3. feed forward (+ residual)
        ff_output = self.ff(self.norm3(hidden_states), deterministic=deterministic)
        hidden_states = ff_output + hidden_states

        return self.dropout_layer(hidden_states, deterministic=deterministic)
class FlaxTransformer2DModel(nn.Module):
    r"""
    A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
    https://arxiv.org/pdf/1506.02025.pdf

    The input feature map is normalised, projected, flattened to a token sequence, passed
    through `depth` transformer blocks, projected back and added to the input.

    Parameters:
        in_channels (:obj:`int`):
            Input number of channels
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        depth (:obj:`int`, *optional*, defaults to 1):
            Number of transformers block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_linear_projection (`bool`, defaults to `False`): tbd
        only_cross_attention (`bool`, defaults to `False`): tbd
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """

    in_channels: int
    n_heads: int
    d_head: int
    depth: int = 1
    dropout: float = 0.0
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False

    def setup(self):
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

        inner_dim = self.n_heads * self.d_head

        def build_projection():
            # dense projection over tokens, or an equivalent 1x1 convolution over the feature map
            if self.use_linear_projection:
                return nn.Dense(inner_dim, dtype=self.dtype)
            return nn.Conv(
                inner_dim,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        self.proj_in = build_projection()

        blocks = []
        for _ in range(self.depth):
            blocks.append(
                FlaxBasicTransformerBlock(
                    inner_dim,
                    self.n_heads,
                    self.d_head,
                    dropout=self.dropout,
                    only_cross_attention=self.only_cross_attention,
                    dtype=self.dtype,
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    split_head_dim=self.split_head_dim,
                )
            )
        self.transformer_blocks = blocks

        self.proj_out = build_projection()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        batch, height, width, channels = hidden_states.shape
        sequence_shape = (batch, height * width, channels)
        image_shape = (batch, height, width, channels)

        skip = hidden_states
        hidden_states = self.norm(hidden_states)

        # the linear variant projects the flattened tokens, the conv variant projects the map first
        if self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states.reshape(sequence_shape))
        else:
            hidden_states = self.proj_in(hidden_states).reshape(sequence_shape)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context, deterministic=deterministic)

        if self.use_linear_projection:
            hidden_states = self.proj_out(hidden_states).reshape(image_shape)
        else:
            hidden_states = self.proj_out(hidden_states.reshape(image_shape))

        hidden_states = hidden_states + skip
        return self.dropout_layer(hidden_states, deterministic=deterministic)
class FlaxFeedForward(nn.Module):
    r"""
    Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
    [`FeedForward`] class, with the following simplifications:
    - The activation function is currently hardcoded to a gated linear unit from:
      https://arxiv.org/abs/2002.05202
    - `dim_out` is equal to `dim`.
    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The output projection keeps the attribute name `net_2` so that it lines up
        # with the index of the matching layer in the Sequential layout.
        self.net_0 = FlaxGEGLU(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        gated = self.net_0(hidden_states, deterministic=deterministic)
        return self.net_2(gated)


class FlaxGEGLU(nn.Module):
    r"""
    Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
    https://arxiv.org/abs/2002.05202.

    Parameters:
        dim (:obj:`int`):
            Input hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        hidden_width = self.dim * 4
        # one projection produces both the linear half and the gate half
        self.proj = nn.Dense(hidden_width * 2, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        projected = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(projected, 2, axis=2)
        gated = hidden_linear * nn.gelu(hidden_gelu)
        return self.dropout_layer(gated, deterministic=deterministic)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import math from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import nn from ..image_processor import IPAdapterMaskProcessor from ..utils import deprecate, logging from ..utils.import_utils import is_torch_npu_available, is_xformers_available from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_torch_npu_available(): import torch_npu if is_xformers_available(): import xformers import xformers.ops else: xformers = None @maybe_allow_in_graph class Attention(nn.Module): r""" A cross attention layer. Parameters: query_dim (`int`): The number of channels in the query. cross_attention_dim (`int`, *optional*): The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. kv_heads (`int`, *optional*, defaults to `None`): The number of key and value heads to use for multi-head attention. Defaults to `heads`. If `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. bias (`bool`, *optional*, defaults to False): Set to `True` for the query, key, and value linear layers to contain a bias parameter. 
upcast_attention (`bool`, *optional*, defaults to False): Set to `True` to upcast the attention computation to `float32`. upcast_softmax (`bool`, *optional*, defaults to False): Set to `True` to upcast the softmax computation to `float32`. cross_attention_norm (`str`, *optional*, defaults to `None`): The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the group norm in the cross attention. added_kv_proj_dim (`int`, *optional*, defaults to `None`): The number of channels to use for the added key and value projections. If `None`, no projection is used. norm_num_groups (`int`, *optional*, defaults to `None`): The number of groups to use for the group norm in the attention. spatial_norm_dim (`int`, *optional*, defaults to `None`): The number of channels to use for the spatial normalization. out_bias (`bool`, *optional*, defaults to `True`): Set to `True` to use a bias in the output linear layer. scale_qk (`bool`, *optional*, defaults to `True`): Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. only_cross_attention (`bool`, *optional*, defaults to `False`): Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if `added_kv_proj_dim` is not `None`. eps (`float`, *optional*, defaults to 1e-5): An additional value added to the denominator in group normalization that is used for numerical stability. rescale_output_factor (`float`, *optional*, defaults to 1.0): A factor to rescale the output by dividing it with this value. residual_connection (`bool`, *optional*, defaults to `False`): Set to `True` to add the residual connection to the output. _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): Set to `True` if the attention block is loaded from a deprecated state dict. 
processor (`AttnProcessor`, *optional*, defaults to `None`): The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and `AttnProcessor` otherwise. """ def __init__( self, query_dim: int, cross_attention_dim: Optional[int] = None, heads: int = 8, kv_heads: Optional[int] = None, dim_head: int = 64, dropout: float = 0.0, bias: bool = False, upcast_attention: bool = False, upcast_softmax: bool = False, cross_attention_norm: Optional[str] = None, cross_attention_norm_num_groups: int = 32, qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, added_proj_bias: Optional[bool] = True, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, out_bias: bool = True, scale_qk: bool = True, only_cross_attention: bool = False, eps: float = 1e-5, rescale_output_factor: float = 1.0, residual_connection: bool = False, _from_deprecated_attn_block: bool = False, processor: Optional["AttnProcessor"] = None, out_dim: int = None, context_pre_only=None, pre_only=False, ): super().__init__() # To prevent circular import. 
from .normalization import FP32LayerNorm, RMSNorm self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads self.query_dim = query_dim self.use_bias = bias self.is_cross_attention = cross_attention_dim is not None self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax self.rescale_output_factor = rescale_output_factor self.residual_connection = residual_connection self.dropout = dropout self.fused_projections = False self.out_dim = out_dim if out_dim is not None else query_dim self.context_pre_only = context_pre_only self.pre_only = pre_only # we make use of this private variable to know whether this class is loaded # with an deprecated state dict so that we can convert it on the fly self._from_deprecated_attn_block = _from_deprecated_attn_block self.scale_qk = scale_qk self.scale = dim_head**-0.5 if self.scale_qk else 1.0 self.heads = out_dim // dim_head if out_dim is not None else heads # for slice_size > 0 the attention score computation # is split across the batch axis to save memory # You can set slice_size with `set_attention_slice` self.sliceable_head_dim = heads self.added_kv_proj_dim = added_kv_proj_dim self.only_cross_attention = only_cross_attention if self.added_kv_proj_dim is None and self.only_cross_attention: raise ValueError( "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." 
) if norm_num_groups is not None: self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) else: self.group_norm = None if spatial_norm_dim is not None: self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) else: self.spatial_norm = None if qk_norm is None: self.norm_q = None self.norm_k = None elif qk_norm == "layer_norm": self.norm_q = nn.LayerNorm(dim_head, eps=eps) self.norm_k = nn.LayerNorm(dim_head, eps=eps) elif qk_norm == "fp32_layer_norm": self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) elif qk_norm == "layer_norm_across_heads": # Lumina applys qk norm across all heads self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) elif qk_norm == "rms_norm": self.norm_q = RMSNorm(dim_head, eps=eps) self.norm_k = RMSNorm(dim_head, eps=eps) else: raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") if cross_attention_norm is None: self.norm_cross = None elif cross_attention_norm == "layer_norm": self.norm_cross = nn.LayerNorm(self.cross_attention_dim) elif cross_attention_norm == "group_norm": if self.added_kv_proj_dim is not None: # The given `encoder_hidden_states` are initially of shape # (batch_size, seq_len, added_kv_proj_dim) before being projected # to (batch_size, seq_len, cross_attention_dim). The norm is applied # before the projection, so we need to use `added_kv_proj_dim` as # the number of channels for the group norm. norm_cross_num_channels = added_kv_proj_dim else: norm_cross_num_channels = self.cross_attention_dim self.norm_cross = nn.GroupNorm( num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True ) else: raise ValueError( f"unknown cross_attention_norm: {cross_attention_norm}. 
Should be None, 'layer_norm' or 'group_norm'" ) self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) else: self.to_k = None self.to_v = None self.added_proj_bias = added_proj_bias if self.added_kv_proj_dim is not None: self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) if self.context_pre_only is not None: self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) if not self.pre_only: self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) if self.context_pre_only is not None and not self.context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) if qk_norm is not None and added_kv_proj_dim is not None: if qk_norm == "fp32_layer_norm": self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) elif qk_norm == "rms_norm": self.norm_added_q = RMSNorm(dim_head, eps=eps) self.norm_added_k = RMSNorm(dim_head, eps=eps) else: self.norm_added_q = None self.norm_added_k = None # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1 if processor is None: processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: r""" Set whether to use npu flash attention from `torch_npu` or not. """ if use_npu_flash_attention: processor = AttnProcessorNPU() else: # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None ) -> None: r""" Set whether to use memory efficient attention from `xformers` or not. Args: use_memory_efficient_attention_xformers (`bool`): Whether to use memory efficient attention from `xformers` or not. attention_op (`Callable`, *optional*): The attention operation to use. Defaults to `None` which uses the default attention operation from `xformers`. 
""" is_custom_diffusion = hasattr(self, "processor") and isinstance( self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), ) is_added_kv_processor = hasattr(self, "processor") and isinstance( self.processor, ( AttnAddedKVProcessor, AttnAddedKVProcessor2_0, SlicedAttnAddedKVProcessor, XFormersAttnAddedKVProcessor, ), ) if use_memory_efficient_attention_xformers: if is_added_kv_processor and is_custom_diffusion: raise NotImplementedError( f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}" ) if not is_xformers_available(): raise ModuleNotFoundError( ( "Refer to https://github.com/facebookresearch/xformers for more information on how to install" " xformers" ), name="xformers", ) elif not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" " only available for GPU " ) else: try: # Make sure we can run the memory efficient attention _ = xformers.ops.memory_efficient_attention( torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), ) except Exception as e: raise e if is_custom_diffusion: processor = CustomDiffusionXFormersAttnProcessor( train_kv=self.processor.train_kv, train_q_out=self.processor.train_q_out, hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, attention_op=attention_op, ) processor.load_state_dict(self.processor.state_dict()) if hasattr(self.processor, "to_k_custom_diffusion"): processor.to(self.processor.to_k_custom_diffusion.weight.device) elif is_added_kv_processor: # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP # which uses this type of cross attention ONLY because the attention mask of format # [0, ..., -10.000, ..., 0, ...,] is not supported # throw warning logger.info( "Memory efficient 
attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." ) processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) else: processor = XFormersAttnProcessor(attention_op=attention_op) else: if is_custom_diffusion: attn_processor_class = ( CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor ) processor = attn_processor_class( train_kv=self.processor.train_kv, train_q_out=self.processor.train_q_out, hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, ) processor.load_state_dict(self.processor.state_dict()) if hasattr(self.processor, "to_k_custom_diffusion"): processor.to(self.processor.to_k_custom_diffusion.weight.device) else: # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) def set_attention_slice(self, slice_size: int) -> None: r""" Set the slice size for attention computation. Args: slice_size (`int`): The slice size for attention computation. 
""" if slice_size is not None and slice_size > self.sliceable_head_dim: raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") if slice_size is not None and self.added_kv_proj_dim is not None: processor = SlicedAttnAddedKVProcessor(slice_size) elif slice_size is not None: processor = SlicedAttnProcessor(slice_size) elif self.added_kv_proj_dim is not None: processor = AttnAddedKVProcessor() else: # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) def set_processor(self, processor: "AttnProcessor") -> None: r""" Set the attention processor to use. Args: processor (`AttnProcessor`): The attention processor to use. """ # if current processor is in `self._modules` and if passed `processor` is not, we need to # pop `processor` from `self._modules` if ( hasattr(self, "processor") and isinstance(self.processor, torch.nn.Module) and not isinstance(processor, torch.nn.Module) ): logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") self._modules.pop("processor") self.processor = processor def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": r""" Get the attention processor in use. Args: return_deprecated_lora (`bool`, *optional*, defaults to `False`): Set to `True` to return the deprecated LoRA attention processor. Returns: "AttentionProcessor": The attention processor in use. 
""" if not return_deprecated_lora: return self.processor def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **cross_attention_kwargs, ) -> torch.Tensor: r""" The forward method of the `Attention` class. Args: hidden_states (`torch.Tensor`): The hidden states of the query. encoder_hidden_states (`torch.Tensor`, *optional*): The hidden states of the encoder. attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. **cross_attention_kwargs: Additional keyword arguments to pass along to the cross attention. Returns: `torch.Tensor`: The output of the attention layer. """ # The `Attention` class can call different attention processors / attention functions # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) quiet_attn_parameters = {"ip_adapter_masks"} unused_kwargs = [ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters ] if len(unused_kwargs) > 0: logger.warning( f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." ) cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} return self.processor( self, hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **cross_attention_kwargs, ) def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: r""" Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` is the number of heads initialized while constructing the `Attention` class. Args: tensor (`torch.Tensor`): The tensor to reshape. Returns: `torch.Tensor`: The reshaped tensor. 
""" head_size = self.heads batch_size, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: r""" Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is the number of heads initialized while constructing the `Attention` class. Args: tensor (`torch.Tensor`): The tensor to reshape. out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is reshaped to `[batch_size * heads, seq_len, dim // heads]`. Returns: `torch.Tensor`: The reshaped tensor. """ head_size = self.heads if tensor.ndim == 3: batch_size, seq_len, dim = tensor.shape extra_dim = 1 else: batch_size, extra_dim, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) tensor = tensor.permute(0, 2, 1, 3) if out_dim == 3: tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) return tensor def get_attention_scores( self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None ) -> torch.Tensor: r""" Compute the attention scores. Args: query (`torch.Tensor`): The query tensor. key (`torch.Tensor`): The key tensor. attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. Returns: `torch.Tensor`: The attention probabilities/scores. 
""" dtype = query.dtype if self.upcast_attention: query = query.float() key = key.float() if attention_mask is None: baddbmm_input = torch.empty( query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device ) beta = 0 else: baddbmm_input = attention_mask beta = 1 attention_scores = torch.baddbmm( baddbmm_input, query, key.transpose(-1, -2), beta=beta, alpha=self.scale, ) del baddbmm_input if self.upcast_softmax: attention_scores = attention_scores.float() attention_probs = attention_scores.softmax(dim=-1) del attention_scores attention_probs = attention_probs.to(dtype) return attention_probs def prepare_attention_mask( self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 ) -> torch.Tensor: r""" Prepare the attention mask for the attention computation. Args: attention_mask (`torch.Tensor`): The attention mask to prepare. target_length (`int`): The target length of the attention mask. This is the length of the attention mask after padding. batch_size (`int`): The batch size, which is used to repeat the attention mask. out_dim (`int`, *optional*, defaults to `3`): The output dimension of the attention mask. Can be either `3` or `4`. Returns: `torch.Tensor`: The prepared attention mask. """ head_size = self.heads if attention_mask is None: return attention_mask current_length: int = attention_mask.shape[-1] if current_length != target_length: if attention_mask.device.type == "mps": # HACK: MPS: Does not support padding by greater than dimension of input tensor. # Instead, we can manually construct the padding tensor. 
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat([attention_mask, padding], dim=2) else: # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: # we want to instead pad by (0, remaining_length), where remaining_length is: # remaining_length: int = target_length - current_length # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) if out_dim == 3: if attention_mask.shape[0] < batch_size * head_size: attention_mask = attention_mask.repeat_interleave(head_size, dim=0) elif out_dim == 4: attention_mask = attention_mask.unsqueeze(1) attention_mask = attention_mask.repeat_interleave(head_size, dim=1) return attention_mask def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: r""" Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the `Attention` class. Args: encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. Returns: `torch.Tensor`: The normalized encoder hidden states. """ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" if isinstance(self.norm_cross, nn.LayerNorm): encoder_hidden_states = self.norm_cross(encoder_hidden_states) elif isinstance(self.norm_cross, nn.GroupNorm): # Group norm norms along the channels dimension and expects # input to be in the shape of (N, C, *). 
@torch.no_grad()
def fuse_projections(self, fuse=True):
    r"""
    Build fused projection layers from the separate query/key/value projections.

    For a non-cross-attention layer a single `to_qkv` linear layer is created from `to_q`, `to_k`
    and `to_v`; for a cross-attention layer `to_kv` is created from `to_k` and `to_v`. When
    `add_q_proj`, `add_k_proj` and `add_v_proj` all exist, `to_added_qkv` is created as well.
    The fused layers are created regardless of `fuse`; the flag is only recorded.

    Args:
        fuse (`bool`, defaults to `True`):
            Value stored in `self.fused_projections`.
    """
    reference = self.to_q.weight.data
    device, dtype = reference.device, reference.dtype

    def build_fused(projections, with_bias):
        # Stack the weight matrices along the output dimension and copy them into a new layer.
        stacked_weight = torch.cat([proj.weight.data for proj in projections])
        out_features, in_features = stacked_weight.shape[0], stacked_weight.shape[1]
        fused_layer = nn.Linear(in_features, out_features, bias=with_bias, device=device, dtype=dtype)
        fused_layer.weight.copy_(stacked_weight)
        if with_bias:
            fused_layer.bias.copy_(torch.cat([proj.bias.data for proj in projections]))
        return fused_layer

    if self.is_cross_attention:
        self.to_kv = build_fused((self.to_k, self.to_v), self.use_bias)
    else:
        self.to_qkv = build_fused((self.to_q, self.to_k, self.to_v), self.use_bias)

    # handle added projections for SD3 and others.
    if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
        self.to_added_qkv = build_fused(
            (self.add_q_proj, self.add_k_proj, self.add_v_proj), self.added_proj_bias
        )

    self.fused_projections = fuse
class CustomDiffusionAttnProcessor(nn.Module):
    r"""
    Attention processor for the Custom Diffusion method.

    The processor optionally owns its own key/value projections and its own query/output
    projections, which replace the ones of the `Attention` layer it is called with.

    Args:
        train_kv (`bool`, defaults to `True`):
            Whether to create (and use) dedicated key and value projections for the text features.
        train_q_out (`bool`, defaults to `True`):
            Whether to create (and use) dedicated query and output projections for the latent
            image features.
        hidden_size (`int`, *optional*, defaults to `None`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`. Falls back to `hidden_size`
            when falsy.
        out_bias (`bool`, defaults to `True`):
            Whether the dedicated output projection has a bias.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability applied after the dedicated output projection.
    """

    def __init__(
        self,
        train_kv: bool = True,
        train_q_out: bool = True,
        hidden_size: Optional[int] = None,
        cross_attention_dim: Optional[int] = None,
        out_bias: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        # Layers carry a `_custom_diffusion` suffix for easy serialization and loading.
        if self.train_kv:
            kv_in_features = cross_attention_dim or hidden_size
            self.to_k_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList(
                [nn.Linear(hidden_size, hidden_size, bias=out_bias), nn.Dropout(dropout)]
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # All projections are brought to the dtype of the layer's own query weights.
        base_dtype = attn.to_q.weight.dtype

        if self.train_q_out:
            query = self.to_q_custom_diffusion(hidden_states).to(base_dtype)
        else:
            query = attn.to_q(hidden_states.to(base_dtype))

        is_cross_attention = encoder_hidden_states is not None
        if not is_cross_attention:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
            key = key.to(base_dtype)
            value = value.to(base_dtype)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        if is_cross_attention:
            # Gate that is 0 for the first token and 1 elsewhere: the first token's key/value
            # go through `.detach()`, so no gradient flows through them.
            grad_gate = torch.ones_like(key)
            grad_gate[:, :1, :] = 0.0
            key = grad_gate * key + (1 - grad_gate) * key.detach()
            value = grad_gate * value + (1 - grad_gate) * value.detach()

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # Output projection followed by dropout, taken from this processor or from the layer.
        out_layers = self.to_out_custom_diffusion if self.train_q_out else attn.to_out
        hidden_states = out_layers[0](hidden_states)
        hidden_states = out_layers[1](hidden_states)

        return hidden_states
class AttnAddedKVProcessor2_0:
    r"""
    Scaled dot-product attention processor (requires `F.scaled_dot_product_attention`) that uses
    the layer's additional `add_k_proj` / `add_v_proj` projections for the encoder hidden states.
    """

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError(
            "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if args or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states

        # Flatten everything after the channel axis and move channels last.
        tokens = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = tokens.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)

        # Without encoder states the (not yet group-normed) tokens serve as context.
        if encoder_hidden_states is None:
            context = tokens
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        tokens = attn.group_norm(tokens.transpose(1, 2)).transpose(1, 2)

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            return attn.head_to_batch_dim(projected, out_dim=4)

        query = split_heads(attn.to_q(tokens))

        added_key = attn.add_k_proj(context)
        added_value = attn.add_v_proj(context)
        added_key = split_heads(added_key)
        added_value = split_heads(added_value)

        if attn.only_cross_attention:
            key, value = added_key, added_value
        else:
            own_key = attn.to_k(tokens)
            own_value = attn.to_v(tokens)
            own_key = split_heads(own_key)
            own_value = split_heads(own_value)
            # Added projections come first along the token axis.
            key = torch.cat([added_key, own_key], dim=2)
            value = torch.cat([added_value, own_value], dim=2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

        # linear proj, then dropout
        attended = attn.to_out[0](attended)
        attended = attn.to_out[1](attended)

        # Back to the layout of the input, then add the skip connection.
        attended = attended.transpose(-1, -2).reshape(residual.shape)
        return attended + residual
query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) # K,V replcaement B, N, D = key.shape if copy_layers is not None and index_block in copy_layers: key[B // 2:][1:] = key[B // 2:][0:1] value[B // 2:][1:] = value[B // 2:][0:1] # `context` projections. encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) # attention query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # Split the attention outputs. 
class PAGJointAttnProcessor2_0:
    """
    Joint (sample + context) attention processor that handles the batch as two halves.

    The first half goes through plain joint attention. The second half uses an additive mask
    that sets attention among the first `identity_block_size` tokens (the sample tokens) to
    `-inf` except on the diagonal. Both halves are concatenated again before returning.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    @staticmethod
    def _joint_attention(
        attn,
        sample,
        context,
        split_at,
        restore_sample,
        restore_context,
        spatial,
        identity_block_size=None,
    ):
        """Run joint attention on one half of the batch; mask sample-to-sample attention when
        `identity_block_size` is given."""
        batch_size = context.shape[0]

        # `sample` projections.
        query = attn.to_q(sample)
        key = attn.to_k(sample)
        value = attn.to_v(sample)

        # `context` projections.
        context_query = attn.add_q_proj(context)
        context_key = attn.add_k_proj(context)
        context_value = attn.add_v_proj(context)

        # Sample tokens first, context tokens after, along the token axis.
        query = torch.cat([query, context_query], dim=1)
        key = torch.cat([key, context_key], dim=1)
        value = torch.cat([value, context_value], dim=1)

        head_dim = key.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        mask = None
        if identity_block_size is not None:
            # All-zero additive mask, then -inf between sample tokens with a zero diagonal.
            seq_len = query.size(2)
            mask = torch.zeros((seq_len, seq_len), device=query.device, dtype=query.dtype)
            mask[:identity_block_size, :identity_block_size] = float("-inf")
            mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
            # Add batch and num_heads dimensions.
            mask = mask.unsqueeze(0).unsqueeze(0)

        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        attended = attended.to(query.dtype)

        # Split the attention outputs back into sample and context parts.
        sample_out = attended[:, :split_at]
        context_out = attended[:, split_at:]

        # linear proj, then dropout
        sample_out = attn.to_out[0](sample_out)
        sample_out = attn.to_out[1](sample_out)
        if not attn.context_pre_only:
            context_out = attn.to_add_out(context_out)

        if restore_sample:
            channel, height, width = spatial
            sample_out = sample_out.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if restore_context:
            channel, height, width = spatial
            context_out = context_out.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return sample_out, context_out

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        residual = hidden_states

        # Spatial sizes of whichever 4-D input was flattened last; used to restore outputs.
        spatial = None

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
            spatial = (channel, height, width)

        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
            spatial = (channel, height, width)

        # Length of the image patch sequence; sizes the block of the mask that prevents
        # interaction between patches in the perturbed path.
        identity_block_size = hidden_states.shape[1]

        sample_org, sample_ptb = hidden_states.chunk(2)
        context_org, context_ptb = encoder_hidden_states.chunk(2)

        shared = {
            "split_at": residual.shape[1],
            "restore_sample": input_ndim == 4,
            "restore_context": context_input_ndim == 4,
            "spatial": spatial,
        }

        # original path
        sample_org, context_org = self._joint_attention(attn, sample_org, context_org, **shared)

        # perturbed path
        sample_ptb, context_ptb = self._joint_attention(
            attn, sample_ptb, context_ptb, identity_block_size=identity_block_size, **shared
        )

        # concat
        hidden_states = torch.cat([sample_org, sample_ptb])
        encoder_hidden_states = torch.cat([context_org, context_ptb])

        return hidden_states, encoder_hidden_states
) def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, *args, **kwargs, ) -> torch.FloatTensor: residual = hidden_states input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) context_input_ndim = encoder_hidden_states.ndim if context_input_ndim == 4: batch_size, channel, height, width = encoder_hidden_states.shape encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) identity_block_size = hidden_states.shape[ 1 ] # patch embeddings width * height (correspond to self-attention map width or height) # chunk hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) ( encoder_hidden_states_uncond, encoder_hidden_states_org, encoder_hidden_states_ptb, ) = encoder_hidden_states.chunk(3) encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org]) ################## original path ################## batch_size = encoder_hidden_states_org.shape[0] # `sample` projections. query_org = attn.to_q(hidden_states_org) key_org = attn.to_k(hidden_states_org) value_org = attn.to_v(hidden_states_org) # `context` projections. 
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org) encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org) encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org) # attention query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1) key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1) value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1) inner_dim = key_org.shape[-1] head_dim = inner_dim // attn.heads query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) hidden_states_org = F.scaled_dot_product_attention( query_org, key_org, value_org, dropout_p=0.0, is_causal=False ) hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states_org = hidden_states_org.to(query_org.dtype) # Split the attention outputs. hidden_states_org, encoder_hidden_states_org = ( hidden_states_org[:, : residual.shape[1]], hidden_states_org[:, residual.shape[1] :], ) # linear proj hidden_states_org = attn.to_out[0](hidden_states_org) # dropout hidden_states_org = attn.to_out[1](hidden_states_org) if not attn.context_pre_only: encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org) if input_ndim == 4: hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) if context_input_ndim == 4: encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape( batch_size, channel, height, width ) ################## perturbed path ################## batch_size = encoder_hidden_states_ptb.shape[0] # `sample` projections. 
query_ptb = attn.to_q(hidden_states_ptb) key_ptb = attn.to_k(hidden_states_ptb) value_ptb = attn.to_v(hidden_states_ptb) # `context` projections. encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb) encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb) encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb) # attention query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1) key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1) value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1) inner_dim = key_ptb.shape[-1] head_dim = inner_dim // attn.heads query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # create a full mask with all entries set to 0 seq_len = query_ptb.size(2) full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype) # set the attention value between image patches to -inf full_mask[:identity_block_size, :identity_block_size] = float("-inf") # set the diagonal of the attention value between image patches to 0 full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0) # expand the mask to match the attention weights shape full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions hidden_states_ptb = F.scaled_dot_product_attention( query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False ) hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype) # split the attention outputs. 
class FusedJointAttnProcessor2_0:
    """Joint attention processor for SD3-like blocks that uses fused QKV projections.

    Sample and context states are each projected by a single fused layer
    (``attn.to_qkv`` / ``attn.to_added_qkv``) whose output is split into
    query, key and value.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            # The message names this class (it previously named AttnProcessor2_0).
            raise ImportError(
                "FusedJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Run joint attention over sample and context tokens.

        Args:
            attn: attention module providing the fused projections and output layers.
            hidden_states: sample states, 3D (batch, tokens, dim) or 4D (batch, channel, height, width).
            encoder_hidden_states: context states, 3D or 4D.
            attention_mask: accepted for interface compatibility, not used.

        Returns:
            Tuple ``(hidden_states, encoder_hidden_states)``.
        """
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections: one fused layer, split into three equal parts.
        qkv = attn.to_qkv(hidden_states)
        split_size = qkv.shape[-1] // 3
        query, key, value = torch.split(qkv, split_size, dim=-1)

        # `context` projections.
        encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
        split_size = encoder_qkv.shape[-1] // 3
        (
            encoder_hidden_states_query_proj,
            encoder_hidden_states_key_proj,
            encoder_hidden_states_value_proj,
        ) = torch.split(encoder_qkv, split_size, dim=-1)

        # attention: sample tokens first, context tokens second on the token axis
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states
hidden_states, encoder_hidden_states = ( hidden_states[:, : residual.shape[1]], hidden_states[:, residual.shape[1] :], ) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if not attn.context_pre_only: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if context_input_ndim == 4: encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) return hidden_states, encoder_hidden_states class AuraFlowAttnProcessor2_0: """Attention processor used typically in processing Aura Flow.""" def __init__(self): if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): raise ImportError( "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. " ) def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, *args, **kwargs, ) -> torch.FloatTensor: batch_size = hidden_states.shape[0] # `sample` projections. query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) # `context` projections. if encoder_hidden_states is not None: encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) # Reshape. inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim) key = key.view(batch_size, -1, attn.heads, head_dim) value = value.view(batch_size, -1, attn.heads, head_dim) # Apply QK norm. 
if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # Concatenate the projections. if encoder_hidden_states is not None: encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( batch_size, -1, attn.heads, head_dim ) encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim) encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( batch_size, -1, attn.heads, head_dim ) if attn.norm_added_q is not None: encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) if attn.norm_added_k is not None: encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj) query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) # Attention. hidden_states = F.scaled_dot_product_attention( query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # Split the attention outputs. 
if encoder_hidden_states is not None: hidden_states, encoder_hidden_states = ( hidden_states[:, encoder_hidden_states.shape[1] :], hidden_states[:, : encoder_hidden_states.shape[1]], ) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if encoder_hidden_states is not None: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) if encoder_hidden_states is not None: return hidden_states, encoder_hidden_states else: return hidden_states class FusedAuraFlowAttnProcessor2_0: """Attention processor used typically in processing Aura Flow with fused projections.""" def __init__(self): if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"): raise ImportError( "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. " ) def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, *args, **kwargs, ) -> torch.FloatTensor: batch_size = hidden_states.shape[0] # `sample` projections. qkv = attn.to_qkv(hidden_states) split_size = qkv.shape[-1] // 3 query, key, value = torch.split(qkv, split_size, dim=-1) # `context` projections. if encoder_hidden_states is not None: encoder_qkv = attn.to_added_qkv(encoder_hidden_states) split_size = encoder_qkv.shape[-1] // 3 ( encoder_hidden_states_query_proj, encoder_hidden_states_key_proj, encoder_hidden_states_value_proj, ) = torch.split(encoder_qkv, split_size, dim=-1) # Reshape. inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim) key = key.view(batch_size, -1, attn.heads, head_dim) value = value.view(batch_size, -1, attn.heads, head_dim) # Apply QK norm. if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # Concatenate the projections. 
if encoder_hidden_states is not None: encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( batch_size, -1, attn.heads, head_dim ) encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim) encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( batch_size, -1, attn.heads, head_dim ) if attn.norm_added_q is not None: encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) if attn.norm_added_k is not None: encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj) query = torch.cat([encoder_hidden_states_query_proj, query], dim=1) key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) # Attention. hidden_states = F.scaled_dot_product_attention( query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # Split the attention outputs. 
# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
    """Apply a rotary embedding to query and key tensors.

    The last axis of ``xq``/``xk`` is read as consecutive pairs. For each pair
    the result is ``freqs_cis[..., 0] * first + freqs_cis[..., 1] * second``,
    i.e. the trailing two axes of ``freqs_cis`` act as a 2x2 matrix per pair.
    The computation runs in float32 and is cast back to the input dtypes.

    Returns:
        Tuple ``(xq_out, xk_out)`` with the shapes and dtypes of the inputs.
    """
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


class FluxSingleAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    When ``index_block`` is listed in ``single_copy_blocks``, the keys and values of every batch entry after the
    first are overwritten, from token ``inject_start_token`` onwards, with those of the first batch entry.
    """

    # First token index whose key/value is copied from batch entry 0.
    # NOTE(review): 512 is presumably the length of the text-token prefix - confirm against the caller.
    inject_start_token = 512

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        attention_store=None,
        step_index=None,
        index_block=None,
        single_copy_blocks=None,
    ) -> torch.Tensor:
        """Run attention, optionally sharing keys/values of batch entry 0 with the other entries.

        Args:
            attn: attention module providing ``to_q``/``to_k``/``to_v``, optional norms and ``heads``.
            hidden_states: input states, 3D (batch, tokens, dim) or 4D (batch, channel, height, width).
            encoder_hidden_states: optional key/value source; defaults to ``hidden_states``.
            attention_mask: accepted for interface compatibility, not used.
            image_rotary_emb: optional rotary embedding applied to query and key.
            attention_store: accepted for interface compatibility, not used.
            step_index: accepted for interface compatibility, not used.
            index_block: identifier of the current block, looked up in ``single_copy_blocks``.
            single_copy_blocks: collection of block identifiers for which K/V injection is active.

        Returns:
            The attention output; no output projection is applied here.
        """
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # K,V replacement
        inject_k_v = single_copy_blocks is not None and index_block in single_copy_blocks
        if inject_k_v:
            start = self.inject_start_token
            # In-place: batch entries 1.. take entry 0's keys/values from `start` on.
            key[1:, start:] = key[:1, start:]
            value[1:, start:] = value[:1, start:]

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            # YiYi to-do: update using apply_rotary_emb
            # from ..embeddings import apply_rotary_emb
            # query = apply_rotary_emb(query, image_rotary_emb)
            # key = apply_rotary_emb(key, image_rotary_emb)
            query, key = apply_rope(query, key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states
class FluxAttnProcessor2_0:
    """Joint (context + sample) attention processor for Flux-style blocks.

    When ``index_block`` is listed in ``mm_copy_blocks``, the sample keys and
    values of every batch entry after the first are overwritten with those of
    the first batch entry before attention is computed.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        attention_store=None,
        step_index=None,
        index_block=None,
        mm_copy_blocks=None,
    ) -> torch.FloatTensor:
        """Run joint attention over context and sample tokens.

        Args:
            attn: attention module providing projections, optional norms, ``heads`` and output layers.
            hidden_states: sample states, 3D (batch, tokens, dim) or 4D (batch, channel, height, width).
            encoder_hidden_states: context states, 3D or 4D.
            attention_mask: accepted for interface compatibility, not used.
            image_rotary_emb: optional rotary embedding applied to the joint query and key.
            attention_store: accepted for interface compatibility, not used.
            step_index: accepted for interface compatibility, not used.
            index_block: identifier of the current block, looked up in ``mm_copy_blocks``.
            mm_copy_blocks: collection of block identifiers for which K/V injection is active.

        Returns:
            Tuple ``(hidden_states, encoder_hidden_states)``.
        """
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # K,V replacement
        inject_k_v = mm_copy_blocks is not None and index_block in mm_copy_blocks
        if inject_k_v:
            # In-place: batch entries 1.. take entry 0's sample keys/values.
            key[1:] = key[:1]
            value[1:] = value[:1]

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)

        if attn.norm_added_q is not None:
            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
        if attn.norm_added_k is not None:
            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

        # attention: context tokens first, sample tokens second on the token axis (dim=2)
        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            # YiYi to-do: update using apply_rotary_emb
            # from ..embeddings import apply_rotary_emb
            # query = apply_rotary_emb(query, image_rotary_emb)
            # key = apply_rotary_emb(key, image_rotary_emb)
            query, key = apply_rope(query, key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs back into context and sample parts.
        encoder_hidden_states, hidden_states = (
            hidden_states[:, : encoder_hidden_states.shape[1]],
            hidden_states[:, encoder_hidden_states.shape[1] :],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states
class XFormersAttnAddedKVProcessor:
    r"""
    Memory-efficient attention through xFormers, with extra key/value projections for the encoder states.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over the added (encoder) keys/values, plus the self keys/values unless cross-attention only.

        The input is flattened to (batch, tokens, channels), processed, restored
        to the input shape and added back onto the input.
        """
        skip = hidden_states

        # Flatten everything after the channel axis into a token axis.
        tokens = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = tokens.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # The context is taken before group norm is applied to the tokens.
        if encoder_hidden_states is None:
            context = tokens
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        tokens = attn.group_norm(tokens.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(tokens))

        added_key = attn.add_k_proj(context)
        added_value = attn.add_v_proj(context)
        added_key = attn.head_to_batch_dim(added_key)
        added_value = attn.head_to_batch_dim(added_value)

        if attn.only_cross_attention:
            key, value = added_key, added_value
        else:
            self_key = attn.to_k(tokens)
            self_value = attn.to_v(tokens)
            self_key = attn.head_to_batch_dim(self_key)
            self_value = attn.head_to_batch_dim(self_value)
            # added (encoder) entries come first on the token axis
            key = torch.cat([added_key, self_key], dim=1)
            value = torch.cat([added_value, self_value], dim=1)

        attended = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        attended = attn.batch_to_head_dim(attended.to(query.dtype))

        # linear proj, then dropout
        projected = attn.to_out[1](attn.to_out[0](attended))

        restored = projected.transpose(-1, -2).reshape(skip.shape)
        return restored + skip
class XFormersAttnProcessor:
    r"""
    Memory-efficient attention through xFormers.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run self- or cross-attention with xFormers and apply the output projection.

        Extra positional arguments or a non-``None`` ``scale`` keyword only
        trigger a deprecation notice; they are otherwise ignored.
        """
        if args or kwargs.get("scale") is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        skip = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_spatial = hidden_states.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, key_tokens, _ = kv_source.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            query_tokens = hidden_states.shape[1]
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attended = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        attended = attn.batch_to_head_dim(attended.to(query.dtype))

        # linear proj, then dropout
        output = attn.to_out[1](attn.to_out[0](attended))

        if is_spatial:
            output = output.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            output = output + skip

        return output / attn.rescale_output_factor
class AttnProcessorNPU:
    r"""
    Flash-attention processor backed by torch_npu.

    The fused NPU kernel is used only for fp16 and bf16 queries. Any other dtype falls back to
    `F.scaled_dot_product_attention`, for which the acceleration effect on NPU is not significant.
    """

    def __init__(self):
        if not is_torch_npu_available():
            raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run self- or cross-attention and apply the output projection.

        Extra positional arguments or a non-``None`` ``scale`` keyword only
        trigger a deprecation notice; they are otherwise ignored.
        """
        if args or kwargs.get("scale") is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        skip = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_spatial = hidden_states.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = key.shape[-1] // attn.heads

        def to_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query, key, value = to_heads(query), to_heads(key), to_heads(value)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        use_fused_kernel = query.dtype in (torch.float16, torch.bfloat16)
        if use_fused_kernel:
            attended = torch_npu.npu_fusion_attention(
                query,
                key,
                value,
                attn.heads,
                input_layout="BNSD",
                pse=None,
                atten_mask=attention_mask,
                scale=1.0 / math.sqrt(query.shape[-1]),
                pre_tockens=65536,
                next_tockens=65536,
                keep_prob=1.0,
                sync=False,
                inner_precise=0,
            )[0]
        else:
            # TODO: add support for attn.scale when we move to Torch 2.1
            attended = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        attended = attended.to(query.dtype)

        # linear proj, then dropout
        output = attn.to_out[1](attn.to_out[0](attended))

        if is_spatial:
            output = output.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            output = output + skip

        return output / attn.rescale_output_factor
class AttnProcessor2_0:
    r"""
    Attention processor built on `F.scaled_dot_product_attention` (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run self- or cross-attention and apply the output projection.

        Extra positional arguments or a non-``None`` ``scale`` keyword only
        trigger a deprecation notice; they are otherwise ignored.
        """
        if args or kwargs.get("scale") is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        skip = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_spatial = hidden_states.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = key.shape[-1] // attn.heads

        def to_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query, key, value = to_heads(query), to_heads(key), to_heads(value)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        attended = attended.to(query.dtype)

        # linear proj, then dropout
        output = attn.to_out[1](attn.to_out[0](attended))

        if is_spatial:
            output = output.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            output = output + skip

        return output / attn.rescale_output_factor
hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class StableAudioAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA. """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) def apply_partial_rotary_emb( self, x: torch.Tensor, freqs_cis: Tuple[torch.Tensor], ) -> torch.Tensor: from .embeddings import apply_rotary_emb rot_dim = freqs_cis[0].shape[-1] x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:] x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2) out = torch.cat((x_rotated, x_unrotated), dim=-1) return out def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: from .embeddings import apply_rotary_emb residual = hidden_states input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects 
attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) head_dim = query.shape[-1] // attn.heads kv_heads = key.shape[-1] // head_dim query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) if kv_heads != attn.heads: # if GQA or MQA, repeat the key/value heads to reach the number of query heads. heads_per_kv_head = attn.heads // kv_heads key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # Apply RoPE if needed if rotary_emb is not None: query_dtype = query.dtype key_dtype = key.dtype query = query.to(torch.float32) key = key.to(torch.float32) rot_dim = rotary_emb[0].shape[-1] query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) query = torch.cat((query_rotated, query_unrotated), dim=-1) if not attn.is_cross_attention: key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:] key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) key = torch.cat((key_rotated, key_unrotated), dim=-1) query = query.to(query_dtype) key = key.to(key_dtype) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = 
class HunyuanAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            # Name this class (not `AttnProcessor2_0`) so the error points at the processor that was requested.
            raise ImportError(
                "HunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Run attention for one `Attention` layer.

        Args:
            attn (`Attention`):
                The attention module whose projection, normalization and output layers are used.
            hidden_states (`torch.Tensor`):
                Query-side input, either 3D `(batch, sequence, channel)` or 4D `(batch, channel, height, width)`.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Key/value-side input. When `None`, `hidden_states` is used (self-attention).
            attention_mask (`torch.Tensor`, *optional*):
                Mask forwarded to `attn.prepare_attention_mask` and then to scaled dot-product attention.
            temb (`torch.Tensor`, *optional*):
                Only forwarded to `attn.spatial_norm` when that layer is set.
            image_rotary_emb (*optional*):
                Rotary embedding passed unchanged to `apply_rotary_emb`. It is always applied to the query and
                applied to the key only when `attn.is_cross_attention` is false.

        Returns:
            `torch.Tensor`: The attention output, with the same number of dimensions as the input
            `hidden_states`.
        """
        from .embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial dimensions into a sequence; undone right before returning.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
PyTorch 2.0) with fused projection layers. This is used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector. """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: from .embeddings import apply_rotary_emb residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) if encoder_hidden_states is None: qkv = attn.to_qkv(hidden_states) split_size = qkv.shape[-1] // 3 query, key, value = torch.split(qkv, split_size, dim=-1) else: if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) query = attn.to_q(hidden_states) kv = attn.to_kv(encoder_hidden_states) split_size = kv.shape[-1] // 2 key, value = torch.split(kv, split_size, dim=-1) inner_dim = key.shape[-1] head_dim = inner_dim // 
class PAGHunyuanAttnProcessor2_0:
    r"""
    Scaled dot-product attention processor (PyTorch 2.0) for the HunyuanDiT model that employs
    [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).

    The incoming batch is split in two halves. The first half goes through regular attention (with optional
    query/key normalization and rotary embedding). The second, "perturbed" half skips the attention computation:
    it is only passed through the value projection and the output layers. Both halves are concatenated again
    before the residual connection and output rescaling.
    """

    def __init__(self):
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError(
                "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        from .embeddings import apply_rotary_emb

        # Saved for the optional residual connection applied to the re-assembled batch.
        skip = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_spatial = hidden_states.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Split the batch into the regular half and the perturbed half.
        original, perturbed = hidden_states.chunk(2)

        # ---- 1. Original path: regular attention ----
        shape_source = original if encoder_hidden_states is None else encoder_hidden_states
        batch_size, kv_length, _ = shape_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, kv_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            original = attn.group_norm(original.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(original)

        if encoder_hidden_states is None:
            context = original
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads
        per_head_shape = (batch_size, -1, num_heads, head_dim)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(*per_head_shape).transpose(1, 2)
        key = key.view(*per_head_shape).transpose(1, 2)
        value = value.view(*per_head_shape).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed; the key is only rotated for self-attention.
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        original = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        original = original.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        original = original.to(query.dtype)

        # linear proj, then dropout
        original = attn.to_out[0](original)
        original = attn.to_out[1](original)

        if is_spatial:
            original = original.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # ---- 2. Perturbed path: value projection only, no attention ----
        if attn.group_norm is not None:
            perturbed = attn.group_norm(perturbed.transpose(1, 2)).transpose(1, 2)

        perturbed = attn.to_v(perturbed)
        perturbed = perturbed.to(query.dtype)

        # linear proj, then dropout
        perturbed = attn.to_out[0](perturbed)
        perturbed = attn.to_out[1](perturbed)

        if is_spatial:
            perturbed = perturbed.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # Re-assemble the full batch in the same order it was split.
        merged = torch.cat([original, perturbed])

        if attn.residual_connection:
            merged = merged + skip

        return merged / attn.rescale_output_factor
""" def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: from .embeddings import apply_rotary_emb residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) # chunk hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) # 1. Original Path batch_size, sequence_length, _ = ( hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states_org) if encoder_hidden_states is None: encoder_hidden_states = hidden_states_org elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 
class LuminaAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the LuminaNextDiT model. It applies a normalization layer and rotary embedding on the query and key
    vectors, and supports grouped-query attention (fewer key/value heads than query heads).

    Unlike most processors in this file, the result is returned per head and without the output projection of
    `attn` being applied.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            # Name this class (not `AttnProcessor2_0`) so the error points at the processor that was requested.
            raise ImportError(
                "LuminaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        query_rotary_emb: Optional[torch.Tensor] = None,
        key_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        r"""
        Run attention for one `Attention` layer.

        Args:
            attn (`Attention`):
                The attention module providing `to_q`, `to_k`, `to_v`, the optional `norm_q` / `norm_k` layers,
                `heads` and `scale`.
            hidden_states (`torch.Tensor`):
                Query-side input, either 3D `(batch, sequence, channel)` or 4D `(batch, channel, height, width)`.
            encoder_hidden_states (`torch.Tensor`):
                Key/value-side input.
            attention_mask (`torch.Tensor`, *optional*):
                Mask over key positions; it is converted to bool and viewed as `(batch, 1, 1, key_length)`. When
                `None`, no masking is applied.
            query_rotary_emb (*optional*):
                Rotary embedding applied to the query via `apply_rotary_emb(..., use_real=False)`.
            key_rotary_emb (*optional*):
                Rotary embedding applied to the key. When `None`, the default softmax scale of
                `scaled_dot_product_attention` is used instead of `attn.scale`.
            base_sequence_length (`int`, *optional*):
                When given together with `key_rotary_emb`, `attn.scale` is multiplied by
                `sqrt(log(sequence_length, base_sequence_length))` (proportional attention).

        Returns:
            `torch.Tensor`: Attention output of shape `(batch, query_length, heads, head_dim)`.
        """
        from .embeddings import apply_rotary_emb

        if hidden_states.ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads
        kv_heads = inner_dim // head_dim

        # Apply Query-Key Norm if needed
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Heads are kept on dim 2 here; the transpose to (batch, heads, seq, head_dim) happens after RoPE.
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply RoPE if needed
        if query_rotary_emb is not None:
            query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
        if key_rotary_emb is not None:
            key = apply_rotary_emb(key, key_rotary_emb, use_real=False)

        # Cast back to the projection dtype after the rotary embedding.
        query, key = query.to(dtype), key.to(dtype)

        # Apply proportional attention if true
        if key_rotary_emb is None:
            softmax_scale = None
        elif base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # Perform Grouped-query Attention (GQA): repeat each key/value head so that every query head has a
        # partner. Nothing to repeat when the head counts already match.
        n_rep = attn.heads // kv_heads
        if n_rep > 1:
            key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

        # `attention_mask` is optional: only reshape it when the caller supplied one.
        if attention_mask is not None:
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
            attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # NOTE: the `scale` keyword is passed through to PyTorch's scaled_dot_product_attention.
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=softmax_scale
        )
        hidden_states = hidden_states.transpose(1, 2).to(dtype)

        return hidden_states
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) if encoder_hidden_states is None: qkv = attn.to_qkv(hidden_states) split_size = qkv.shape[-1] // 3 query, key, value = torch.split(qkv, split_size, dim=-1) else: if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) query = attn.to_q(hidden_states) kv = attn.to_kv(encoder_hidden_states) split_size = kv.shape[-1] // 2 key, value = torch.split(kv, split_size, dim=-1) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = 
F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class CustomDiffusionXFormersAttnProcessor(nn.Module): r""" Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method. Args: train_kv (`bool`, defaults to `True`): Whether to newly train the key and value matrices corresponding to the text features. train_q_out (`bool`, defaults to `True`): Whether to newly train query matrices corresponding to the latent image features. hidden_size (`int`, *optional*, defaults to `None`): The hidden size of the attention layer. cross_attention_dim (`int`, *optional*, defaults to `None`): The number of channels in the `encoder_hidden_states`. out_bias (`bool`, defaults to `True`): Whether to include the bias parameter in `train_q_out`. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. attention_op (`Callable`, *optional*, defaults to `None`): The base [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. 
""" def __init__( self, train_kv: bool = True, train_q_out: bool = False, hidden_size: Optional[int] = None, cross_attention_dim: Optional[int] = None, out_bias: bool = True, dropout: float = 0.0, attention_op: Optional[Callable] = None, ): super().__init__() self.train_kv = train_kv self.train_q_out = train_q_out self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.attention_op = attention_op # `_custom_diffusion` id for easy serialization and loading. if self.train_kv: self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) if self.train_q_out: self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) self.to_out_custom_diffusion = nn.ModuleList([]) self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) self.to_out_custom_diffusion.append(nn.Dropout(dropout)) def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if self.train_q_out: query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype) else: query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype)) if encoder_hidden_states is None: crossattn = False encoder_hidden_states = hidden_states else: crossattn = True if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) if self.train_kv: key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) key = 
key.to(attn.to_q.weight.dtype) value = value.to(attn.to_q.weight.dtype) else: key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) if crossattn: detach = torch.ones_like(key) detach[:, :1, :] = detach[:, :1, :] * 0.0 key = detach * key + (1 - detach) * key.detach() value = detach * value + (1 - detach) * value.detach() query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() value = attn.head_to_batch_dim(value).contiguous() hidden_states = xformers.ops.memory_efficient_attention( query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale ) hidden_states = hidden_states.to(query.dtype) hidden_states = attn.batch_to_head_dim(hidden_states) if self.train_q_out: # linear proj hidden_states = self.to_out_custom_diffusion[0](hidden_states) # dropout hidden_states = self.to_out_custom_diffusion[1](hidden_states) else: # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) return hidden_states class CustomDiffusionAttnProcessor2_0(nn.Module): r""" Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled dot-product attention. Args: train_kv (`bool`, defaults to `True`): Whether to newly train the key and value matrices corresponding to the text features. train_q_out (`bool`, defaults to `True`): Whether to newly train query matrices corresponding to the latent image features. hidden_size (`int`, *optional*, defaults to `None`): The hidden size of the attention layer. cross_attention_dim (`int`, *optional*, defaults to `None`): The number of channels in the `encoder_hidden_states`. out_bias (`bool`, defaults to `True`): Whether to include the bias parameter in `train_q_out`. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 
""" def __init__( self, train_kv: bool = True, train_q_out: bool = True, hidden_size: Optional[int] = None, cross_attention_dim: Optional[int] = None, out_bias: bool = True, dropout: float = 0.0, ): super().__init__() self.train_kv = train_kv self.train_q_out = train_q_out self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim # `_custom_diffusion` id for easy serialization and loading. if self.train_kv: self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) if self.train_q_out: self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) self.to_out_custom_diffusion = nn.ModuleList([]) self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) self.to_out_custom_diffusion.append(nn.Dropout(dropout)) def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if self.train_q_out: query = self.to_q_custom_diffusion(hidden_states) else: query = attn.to_q(hidden_states) if encoder_hidden_states is None: crossattn = False encoder_hidden_states = hidden_states else: crossattn = True if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) if self.train_kv: key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) key = key.to(attn.to_q.weight.dtype) value = value.to(attn.to_q.weight.dtype) else: key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) if crossattn: detach = 
class SlicedAttnProcessor:
    r"""
    Attention processor that computes attention in slices instead of in one batched matrix product.

    Args:
        slice_size (`int`):
            Number of rows of the head-batched query (its first dimension after `attn.head_to_batch_dim`) that are
            processed per slice.
    """

    def __init__(self, slice_size: int):
        self.slice_size = slice_size

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply sliced attention.

        Args:
            attn: Attention module providing projections, normalisation and head reshaping helpers.
            hidden_states: Query-side input, either 3-D or 4-D (4-D inputs are flattened and restored).
            encoder_hidden_states: Key/value-side input; `hidden_states` is used when it is `None`.
            attention_mask: Optional mask, normalised through `attn.prepare_attention_mask`.

        Returns:
            The attended hidden states, in the same layout as the input.
        """
        skip_input = hidden_states
        ndim_in = hidden_states.ndim

        if ndim_in == 4:
            # Flatten the two trailing spatial axes into a sequence axis.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            context_shape = hidden_states.shape
        else:
            context_shape = encoder_hidden_states.shape
        batch_size, sequence_length, _ = context_shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        projected_query = attn.to_q(hidden_states)
        inner_dim = projected_query.shape[-1]
        query = attn.head_to_batch_dim(projected_query)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        projected_key = attn.to_k(context)
        projected_value = attn.to_v(context)
        key = attn.head_to_batch_dim(projected_key)
        value = attn.head_to_batch_dim(projected_value)

        head_batch, query_tokens, _ = query.shape
        output = torch.zeros(
            (head_batch, query_tokens, inner_dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Ceil-divide so that a trailing partial slice is still processed.
        num_slices = (head_batch - 1) // self.slice_size + 1
        for slice_index in range(num_slices):
            rows = slice(slice_index * self.slice_size, (slice_index + 1) * self.slice_size)

            mask_rows = None if attention_mask is None else attention_mask[rows]
            probs = attn.get_attention_scores(query[rows], key[rows], mask_rows)

            output[rows] = torch.bmm(probs, value[rows])

        hidden_states = attn.batch_to_head_dim(output)

        # linear proj, then dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if ndim_in == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip_input

        return hidden_states / attn.rescale_output_factor
class SlicedAttnAddedKVProcessor:
    r"""
    Sliced attention processor that uses the extra `add_k_proj` / `add_v_proj` projections of the attention module
    for the encoder states.

    Args:
        slice_size (`int`):
            Number of rows of the head-batched query (its first dimension after `attn.head_to_batch_dim`) that are
            processed per slice.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply sliced attention with added key/value projections.

        Args:
            attn: Attention module providing projections, normalisation and head reshaping helpers.
            hidden_states: Input whose trailing axes (after the first two) are flattened into a sequence axis.
            encoder_hidden_states: Encoder-side input; the flattened `hidden_states` are used when it is `None`.
            attention_mask: Optional mask, normalised through `attn.prepare_attention_mask`.
            temb: Passed to `attn.spatial_norm` when that layer is present.

        Returns:
            The attended states reshaped to the input's shape, with the input added back.
        """
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        lead, channels = hidden_states.shape[0], hidden_states.shape[1]
        hidden_states = hidden_states.view(lead, channels, -1).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # Note: the self-attention fallback uses the states *before* group norm is applied.
        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        projected_query = attn.to_q(hidden_states)
        inner_dim = projected_query.shape[-1]
        query = attn.head_to_batch_dim(projected_query)

        added_key = attn.add_k_proj(context)
        added_value = attn.add_v_proj(context)
        added_key = attn.head_to_batch_dim(added_key)
        added_value = attn.head_to_batch_dim(added_value)

        if attn.only_cross_attention:
            key, value = added_key, added_value
        else:
            self_key = attn.to_k(hidden_states)
            self_value = attn.to_v(hidden_states)
            self_key = attn.head_to_batch_dim(self_key)
            self_value = attn.head_to_batch_dim(self_value)
            # Encoder tokens come first, followed by the self-attention tokens.
            key = torch.cat([added_key, self_key], dim=1)
            value = torch.cat([added_value, self_value], dim=1)

        head_batch, query_tokens, _ = query.shape
        output = torch.zeros(
            (head_batch, query_tokens, inner_dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Ceil-divide so that a trailing partial slice is still processed.
        num_slices = (head_batch - 1) // self.slice_size + 1
        for slice_index in range(num_slices):
            rows = slice(slice_index * self.slice_size, (slice_index + 1) * self.slice_size)

            mask_rows = None if attention_mask is None else attention_mask[rows]
            probs = attn.get_attention_scores(query[rows], key[rows], mask_rows)

            output[rows] = torch.bmm(probs, value[rows])

        hidden_states = attn.batch_to_head_dim(output)

        # linear proj, then dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(skip_input.shape)
        return hidden_states + skip_input
class SpatialNorm(nn.Module):
    """
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.

    The input is group-normalised, then multiplied and shifted by two 1x1 convolutions of the conditioning map.

    Args:
        f_channels (`int`):
            The number of channels for input to group normalization layer, and output of the spatial norm layer.
        zq_channels (`int`):
            The number of channels for the quantized vector as described in the paper.
    """

    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
    ):
        super().__init__()
        pointwise = {"kernel_size": 1, "stride": 1, "padding": 0}

        self.norm_layer = nn.GroupNorm(32, f_channels, eps=1e-6, affine=True)
        # conv_y produces the multiplicative term, conv_b the additive term.
        self.conv_y = nn.Conv2d(zq_channels, f_channels, **pointwise)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, **pointwise)

    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
        """
        Normalise `f` and modulate it with `zq`.

        Args:
            f: Feature map to normalise.
            zq: Conditioning map; resized with nearest-neighbour interpolation to the last two dimensions of `f`.

        Returns:
            `norm_layer(f) * conv_y(zq) + conv_b(zq)`.
        """
        resized_zq = F.interpolate(zq, size=f.shape[-2:], mode="nearest")
        normalized = self.norm_layer(f)
        gain = self.conv_y(resized_zq)
        shift = self.conv_b(resized_zq)
        return normalized * gain + shift
class IPAdapterAttnProcessor(nn.Module):
    r"""
    Attention processor for Multiple IP-Adapters.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
            The context length of the image features.
        scale (`float` or List[`float`], defaults to 1.0):
            the weight scale of image prompt.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]
        self.num_tokens = num_tokens

        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        if len(scale) != len(num_tokens):
            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
        self.scale = scale

        # One key projection and one value projection per adapter.
        self.to_k_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )
        self.to_v_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        scale: float = 1.0,
        ip_adapter_masks: Optional[torch.Tensor] = None,
    ):
        """
        Run text attention and add the scaled contribution of every IP-Adapter.

        Args:
            attn: Attention module providing projections, normalisation and head reshaping helpers.
            hidden_states: Query-side input, either 3-D or 4-D (4-D inputs are flattened and restored).
            encoder_hidden_states: Preferably a tuple `(text_states, ip_hidden_states)`. A plain tensor is still
                accepted (deprecated); its last `self.num_tokens[0]` tokens are then treated as the image tokens.
                When `None`, no adapter contribution is added.
            attention_mask: Optional mask for the text attention, normalised via `attn.prepare_attention_mask`.
            temb: Passed to `attn.spatial_norm` when that layer is present.
            scale: Not read by this processor; the per-adapter weights come from `self.scale`.
            ip_adapter_masks: Optional list of 4-D mask tensors, one per adapter. A single tensor is also accepted
                for backward compatibility.

        Returns:
            The attended hidden states, in the same layout as the input.

        Raises:
            ValueError: If `ip_adapter_masks` does not match `self.scale` / the image states in length or shape.
        """
        residual = hidden_states

        # Bound up front so that a call without `encoder_hidden_states` adds no adapter contribution instead of
        # failing with an UnboundLocalError in the adapter loop below.
        ip_hidden_states = []

        # separate ip_hidden_states from encoder_hidden_states
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, tuple):
                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
            else:
                deprecation_message = (
                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
                )
                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
                encoder_hidden_states, ip_hidden_states = (
                    encoder_hidden_states[:, :end_pos, :],
                    [encoder_hidden_states[:, end_pos:, :]],
                )

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_adapter_masks is not None:
            if not isinstance(ip_adapter_masks, list):
                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                raise ValueError(
                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
                    f"({len(ip_hidden_states)})"
                )
            for index, (mask, ip_scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
                if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                    raise ValueError(
                        "Each element of the ip_adapter_masks array should be a tensor with shape "
                        "[1, num_images_for_ip_adapter, height, width]."
                        " Please use `IPAdapterMaskProcessor` to preprocess your mask"
                    )
                if mask.shape[1] != ip_state.shape[1]:
                    raise ValueError(
                        f"Number of masks ({mask.shape[1]}) does not match "
                        f"number of ip images ({ip_state.shape[1]}) at index {index}"
                    )
                if isinstance(ip_scale, list) and not len(ip_scale) == mask.shape[1]:
                    raise ValueError(
                        f"Number of masks ({mask.shape[1]}) does not match "
                        f"number of scales ({len(ip_scale)}) at index {index}"
                    )
        else:
            ip_adapter_masks = [None] * len(self.scale)

        # for ip-adapter
        for current_ip_hidden_states, ip_scale, to_k_ip, to_v_ip, mask in zip(
            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
        ):
            # An adapter whose weight(s) are all zero contributes nothing, so skip its attention entirely.
            if isinstance(ip_scale, list):
                skip = all(s == 0 for s in ip_scale)
            else:
                skip = ip_scale == 0
            if skip:
                continue

            if mask is not None:
                if not isinstance(ip_scale, list):
                    ip_scale = [ip_scale] * mask.shape[1]

                current_num_images = mask.shape[1]
                for i in range(current_num_images):
                    ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
                    ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

                    ip_key = attn.head_to_batch_dim(ip_key)
                    ip_value = attn.head_to_batch_dim(ip_value)

                    ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
                    _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
                    _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)

                    mask_downsample = IPAdapterMaskProcessor.downsample(
                        mask[:, i, :, :],
                        batch_size,
                        _current_ip_hidden_states.shape[1],
                        _current_ip_hidden_states.shape[2],
                    )
                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

                    hidden_states = hidden_states + ip_scale[i] * (_current_ip_hidden_states * mask_downsample)
            else:
                ip_key = to_k_ip(current_ip_hidden_states)
                ip_value = to_v_ip(current_ip_hidden_states)

                ip_key = attn.head_to_batch_dim(ip_key)
                ip_value = attn.head_to_batch_dim(ip_value)

                ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
                current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
                current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

                hidden_states = hidden_states + ip_scale * current_ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class IPAdapterAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
            The context length of the image features.
        scale (`float` or `List[float]`, defaults to 1.0):
            the weight scale of image prompt.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]
        self.num_tokens = num_tokens

        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        if len(scale) != len(num_tokens):
            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
        self.scale = scale

        # One key projection and one value projection per adapter.
        self.to_k_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )
        self.to_v_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        scale: float = 1.0,
        ip_adapter_masks: Optional[torch.Tensor] = None,
    ):
        """
        Run text attention with `F.scaled_dot_product_attention` and add every IP-Adapter's scaled contribution.

        Args:
            attn: Attention module providing projections, normalisation, head count and output layers.
            hidden_states: Query-side input, either 3-D or 4-D (4-D inputs are flattened and restored).
            encoder_hidden_states: Preferably a tuple `(text_states, ip_hidden_states)`. A plain tensor is still
                accepted (deprecated); its last `self.num_tokens[0]` tokens are then treated as the image tokens.
                When `None`, no adapter contribution is added.
            attention_mask: Optional mask for the text attention.
            temb: Passed to `attn.spatial_norm` when that layer is present.
            scale: Not read by this processor; the per-adapter weights come from `self.scale`.
            ip_adapter_masks: Optional list of 4-D mask tensors, one per adapter. A single tensor is also accepted
                for backward compatibility.

        Returns:
            The attended hidden states, in the same layout as the input.

        Raises:
            ValueError: If `ip_adapter_masks` does not match `self.scale` / the image states in length or shape.
        """
        residual = hidden_states

        # Bound up front so that a call without `encoder_hidden_states` adds no adapter contribution instead of
        # failing with an UnboundLocalError in the adapter loop below.
        ip_hidden_states = []

        # separate ip_hidden_states from encoder_hidden_states
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, tuple):
                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
            else:
                deprecation_message = (
                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
                )
                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
                encoder_hidden_states, ip_hidden_states = (
                    encoder_hidden_states[:, :end_pos, :],
                    [encoder_hidden_states[:, end_pos:, :]],
                )

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        def split_heads(tensor: torch.Tensor) -> torch.Tensor:
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        def merge_heads(tensor: torch.Tensor) -> torch.Tensor:
            # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim), cast back to the query dtype
            tensor = tensor.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            return tensor.to(query.dtype)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = merge_heads(hidden_states)

        if ip_adapter_masks is not None:
            if not isinstance(ip_adapter_masks, list):
                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                raise ValueError(
                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
                    f"({len(ip_hidden_states)})"
                )
            for index, (mask, ip_scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
                if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                    raise ValueError(
                        "Each element of the ip_adapter_masks array should be a tensor with shape "
                        "[1, num_images_for_ip_adapter, height, width]."
                        " Please use `IPAdapterMaskProcessor` to preprocess your mask"
                    )
                if mask.shape[1] != ip_state.shape[1]:
                    raise ValueError(
                        f"Number of masks ({mask.shape[1]}) does not match "
                        f"number of ip images ({ip_state.shape[1]}) at index {index}"
                    )
                if isinstance(ip_scale, list) and not len(ip_scale) == mask.shape[1]:
                    raise ValueError(
                        f"Number of masks ({mask.shape[1]}) does not match "
                        f"number of scales ({len(ip_scale)}) at index {index}"
                    )
        else:
            ip_adapter_masks = [None] * len(self.scale)

        # for ip-adapter
        for current_ip_hidden_states, ip_scale, to_k_ip, to_v_ip, mask in zip(
            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
        ):
            # An adapter whose weight(s) are all zero contributes nothing, so skip its attention entirely.
            if isinstance(ip_scale, list):
                skip = all(s == 0 for s in ip_scale)
            else:
                skip = ip_scale == 0
            if skip:
                continue

            if mask is not None:
                if not isinstance(ip_scale, list):
                    ip_scale = [ip_scale] * mask.shape[1]

                current_num_images = mask.shape[1]
                for i in range(current_num_images):
                    ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
                    ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

                    ip_key = split_heads(ip_key)
                    ip_value = split_heads(ip_value)

                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
                    # TODO: add support for attn.scale when we move to Torch 2.1
                    _current_ip_hidden_states = F.scaled_dot_product_attention(
                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                    )
                    _current_ip_hidden_states = merge_heads(_current_ip_hidden_states)

                    mask_downsample = IPAdapterMaskProcessor.downsample(
                        mask[:, i, :, :],
                        batch_size,
                        _current_ip_hidden_states.shape[1],
                        _current_ip_hidden_states.shape[2],
                    )
                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

                    hidden_states = hidden_states + ip_scale[i] * (_current_ip_hidden_states * mask_downsample)
            else:
                ip_key = to_k_ip(current_ip_hidden_states)
                ip_value = to_v_ip(current_ip_hidden_states)

                ip_key = split_heads(ip_key)
                ip_value = split_heads(ip_value)

                # the output of sdp = (batch, num_heads, seq_len, head_dim)
                # TODO: add support for attn.scale when we move to Torch 2.1
                current_ip_hidden_states = F.scaled_dot_product_attention(
                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                )
                current_ip_hidden_states = merge_heads(current_ip_hidden_states)

                hidden_states = hidden_states + ip_scale * current_ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class PAGIdentitySelfAttnProcessor2_0:
    r"""
    Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    PAG reference: https://arxiv.org/abs/2403.17377

    The batch is split in two halves: the first goes through regular self-attention, the second through the value
    projection only ("identity attention").
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """
        Apply self-attention to the first half of the batch and identity attention to the second half.

        Args:
            attn: Attention module providing projections, normalisation, head count and output layers.
            hidden_states: Input holding both halves along the batch axis; 3-D or 4-D.
            encoder_hidden_states: Not read by this processor.
            attention_mask: Optional mask applied to the regular self-attention half only.
            temb: Passed to `attn.spatial_norm` when that layer is present.

        Returns:
            Both halves concatenated along the batch axis, in the same layout as the input.
        """
        skip_input = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        ndim_in = hidden_states.ndim
        if ndim_in == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        def project_out(states: torch.Tensor, num_samples: int) -> torch.Tensor:
            # linear proj, dropout, then restore the 4-D layout if the input had one
            states = attn.to_out[0](states)
            states = attn.to_out[1](states)
            if ndim_in == 4:
                states = states.transpose(-1, -2).reshape(num_samples, channel, height, width)
            return states

        # chunk
        org_states, ptb_states = hidden_states.chunk(2)

        # ---- original path ----
        org_batch, sequence_length, _ = org_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, org_batch)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(org_batch, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            org_states = attn.group_norm(org_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(org_states)
        key = attn.to_k(org_states)
        value = attn.to_v(org_states)

        head_dim = key.shape[-1] // attn.heads
        head_layout = (org_batch, -1, attn.heads, head_dim)

        query = query.view(*head_layout).transpose(1, 2)
        key = key.view(*head_layout).transpose(1, 2)
        value = value.view(*head_layout).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        org_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        org_states = org_states.transpose(1, 2).reshape(org_batch, -1, attn.heads * head_dim)
        org_states = org_states.to(query.dtype)
        org_states = project_out(org_states, org_batch)

        # ---- perturbed path (identity attention) ----
        ptb_batch, sequence_length, _ = ptb_states.shape

        if attn.group_norm is not None:
            ptb_states = attn.group_norm(ptb_states.transpose(1, 2)).transpose(1, 2)

        # Only the value projection is applied; no query/key attention weights are computed.
        ptb_states = attn.to_v(ptb_states)
        ptb_states = ptb_states.to(query.dtype)
        ptb_states = project_out(ptb_states, ptb_batch)

        # cat
        hidden_states = torch.cat([org_states, ptb_states])

        if attn.residual_connection:
            hidden_states = hidden_states + skip_input

        return hidden_states / attn.rescale_output_factor
class PAGCFGIdentitySelfAttnProcessor2_0:
    r"""
    Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    PAG reference: https://arxiv.org/abs/2403.17377

    The batch is split in three chunks: the first two go through regular self-attention together, the third through
    the value projection only ("identity attention").
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """
        Apply self-attention to the first two chunks of the batch and identity attention to the third.

        Args:
            attn: Attention module providing projections, normalisation, head count and output layers.
            hidden_states: Input holding the three chunks along the batch axis; 3-D or 4-D.
            encoder_hidden_states: Not read by this processor.
            attention_mask: Optional mask applied to the regular self-attention chunks only.
            temb: Passed to `attn.spatial_norm` when that layer is present.

        Returns:
            All chunks concatenated along the batch axis, in the same layout as the input.
        """
        skip_input = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        ndim_in = hidden_states.ndim
        if ndim_in == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        def project_out(states: torch.Tensor, num_samples: int) -> torch.Tensor:
            # linear proj, dropout, then restore the 4-D layout if the input had one
            states = attn.to_out[0](states)
            states = attn.to_out[1](states)
            if ndim_in == 4:
                states = states.transpose(-1, -2).reshape(num_samples, channel, height, width)
            return states

        # chunk
        uncond_states, cond_states, ptb_states = hidden_states.chunk(3)
        org_states = torch.cat([uncond_states, cond_states])

        # ---- original path ----
        org_batch, sequence_length, _ = org_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, org_batch)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(org_batch, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            org_states = attn.group_norm(org_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(org_states)
        key = attn.to_k(org_states)
        value = attn.to_v(org_states)

        head_dim = key.shape[-1] // attn.heads
        head_layout = (org_batch, -1, attn.heads, head_dim)

        query = query.view(*head_layout).transpose(1, 2)
        key = key.view(*head_layout).transpose(1, 2)
        value = value.view(*head_layout).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        org_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        org_states = org_states.transpose(1, 2).reshape(org_batch, -1, attn.heads * head_dim)
        org_states = org_states.to(query.dtype)
        org_states = project_out(org_states, org_batch)

        # ---- perturbed path (identity attention) ----
        ptb_batch, sequence_length, _ = ptb_states.shape

        if attn.group_norm is not None:
            ptb_states = attn.group_norm(ptb_states.transpose(1, 2)).transpose(1, 2)

        # Only the value projection is applied; no query/key attention weights are computed.
        ptb_states = attn.to_v(ptb_states)
        ptb_states = ptb_states.to(query.dtype)
        ptb_states = project_out(ptb_states, ptb_batch)

        # cat
        hidden_states = torch.cat([org_states, ptb_states])

        if attn.residual_connection:
            hidden_states = hidden_states + skip_input

        return hidden_states / attn.rescale_output_factor
class LoRAAttnProcessor:
    """Empty class: the constructor takes no arguments and stores no state."""

    def __init__(self) -> None:
        """Do nothing."""


class LoRAAttnProcessor2_0:
    """Empty class: the constructor takes no arguments and stores no state."""

    def __init__(self) -> None:
        """Do nothing."""


class LoRAXFormersAttnProcessor:
    """Empty class: the constructor takes no arguments and stores no state."""

    def __init__(self) -> None:
        """Do nothing."""


class LoRAAttnAddedKVProcessor:
    """Empty class: the constructor takes no arguments and stores no state."""

    def __init__(self) -> None:
        """Do nothing."""


# Processor classes grouped under the "added KV" label.
ADDED_KV_ATTENTION_PROCESSORS = (
    AttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    XFormersAttnAddedKVProcessor,
)

# Processor classes grouped under the "cross attention" label.
CROSS_ATTENTION_PROCESSORS = (
    AttnProcessor,
    AttnProcessor2_0,
    XFormersAttnProcessor,
    SlicedAttnProcessor,
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
)

# Type alias covering the processor classes listed below.
AttentionProcessor = Union[
    AttnProcessor,
    AttnProcessor2_0,
    FusedAttnProcessor2_0,
    XFormersAttnProcessor,
    SlicedAttnProcessor,
    AttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    XFormersAttnAddedKVProcessor,
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
    PAGCFGIdentitySelfAttnProcessor2_0,
    PAGIdentitySelfAttnProcessor2_0,
    PAGCFGHunyuanAttnProcessor2_0,
    PAGHunyuanAttnProcessor2_0,
]
class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
    r"""
    Asymmetric VAE from "Designing a Better Asymmetric VQGAN for StableDiffusion"
    (https://arxiv.org/abs/2306.04632).

    The model is a KL-regularised autoencoder: images are encoded into a latent distribution and latents are decoded
    back into images. The decoder is a `MaskConditionDecoder`, so decoding can additionally receive the original image
    and an inpainting mask.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of down block output channels.
        layers_per_down_block (`int`, *optional*, defaults to `1`):
            Number of layers per down block.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of up block output channels.
        layers_per_up_block (`int`, *optional*, defaults to `1`):
            Number of layers per up block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            Number of groups to use for the first normalization layer in ResNet blocks.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. It is used to scale the latent space to unit variance when training the diffusion model:
            latents are scaled with `z = z * scaling_factor` before being passed to the diffusion model, and scaled
            back with `z = 1 / scaling_factor * z` when decoding. For more details, refer to sections 4.3.2 and D.1 of
            the [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)
            paper.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        down_block_out_channels: Tuple[int, ...] = (64,),
        layers_per_down_block: int = 1,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        up_block_out_channels: Tuple[int, ...] = (64,),
        layers_per_up_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
    ) -> None:
        super().__init__()

        # Sub-modules are created in a fixed order (encoder, decoder, quant_conv, post_quant_conv)
        # so that parameter registration and initialisation order stay stable.
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=down_block_out_channels,
            layers_per_block=layers_per_down_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )
        self.decoder = MaskConditionDecoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=up_block_out_channels,
            layers_per_block=layers_per_up_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
        )

        # 1x1 convolutions applied to the encoder output and to the latents before decoding.
        moments_channels = 2 * latent_channels
        self.quant_conv = nn.Conv2d(moments_channels, moments_channels, 1)
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)

        self.use_slicing = False
        self.use_tiling = False

        # Extra config entries that are not constructor arguments.
        self.register_to_config(block_out_channels=up_block_out_channels)
        self.register_to_config(force_upcast=False)

    @apply_forward_hook
    def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor]]:
        """
        Encode `x` into a latent distribution.

        Args:
            x (`torch.Tensor`): Input passed to the encoder.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            An [`AutoencoderKLOutput`] whose `latent_dist` is the `DiagonalGaussianDistribution` built from the
            encoder output, or a 1-tuple holding that distribution when `return_dict` is False.
        """
        latent_dist = DiagonalGaussianDistribution(self.quant_conv(self.encoder(x)))

        if return_dict:
            return AutoencoderKLOutput(latent_dist=latent_dist)
        return (latent_dist,)

    def _decode(
        self,
        z: torch.Tensor,
        image: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        """Run `post_quant_conv` and the mask-conditioned decoder on `z`, forwarding `image` and `mask`."""
        reconstruction = self.decoder(self.post_quant_conv(z), image, mask)

        if return_dict:
            return DecoderOutput(sample=reconstruction)
        return (reconstruction,)

    @apply_forward_hook
    def decode(
        self,
        z: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        image: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        """
        Decode latents `z`, optionally conditioned on `image` and `mask`.

        Args:
            z (`torch.Tensor`): Latents to decode.
            generator (`torch.Generator`, *optional*): Accepted but not used by this method.
            image (`torch.Tensor`, *optional*): Forwarded to the decoder.
            mask (`torch.Tensor`, *optional*): Forwarded to the decoder.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`DecoderOutput`] instead of a plain tuple.
        """
        decoder_output = self._decode(z, image, mask)
        decoded = decoder_output.sample

        if return_dict:
            return DecoderOutput(sample=decoded)
        return (decoded,)

    def forward(
        self,
        sample: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        r"""
        Encode `sample`, pick a latent from the posterior and decode it again.

        Args:
            sample (`torch.Tensor`): Input sample.
            mask (`torch.Tensor`, *optional*, defaults to `None`): Optional inpainting mask.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior; otherwise its mode is used.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                Generator used when sampling from the posterior.
        """
        latent_dist = self.encode(sample).latent_dist
        latents = latent_dist.sample(generator=generator) if sample_posterior else latent_dist.mode()

        # The input sample doubles as the conditioning image of the decoder.
        reconstruction = self.decode(latents, generator, sample, mask).sample

        if return_dict:
            return DecoderOutput(sample=reconstruction)
        return (reconstruction,)
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `1`):
            Forwarded to both the encoder and the decoder as `layers_per_block`.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            Forwarded to both the encoder and the decoder as `norm_num_groups`.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        shift_factor (`float`, *optional*), latents_mean (`Tuple[float]`, *optional*),
        latents_std (`Tuple[float]`, *optional*):
            Not used by this class itself; presumably only recorded in the config for consumers of the model.
        force_upcast (`bool`, *optional*, defaults to `True`):
            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
            can be fine-tuned / trained to a lower range without losing too much precision in which case
            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
        use_quant_conv (`bool`, *optional*, defaults to `True`):
            Whether to create the 1x1 `quant_conv` applied to the encoder output.
        use_post_quant_conv (`bool`, *optional*, defaults to `True`):
            Whether to create the 1x1 `post_quant_conv` applied to the latents before decoding.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
            mid_block will only have resnet blocks.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
            mid_block_add_attention=mid_block_add_attention,
        )

        # pass init params to Decoder
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            mid_block_add_attention=mid_block_add_attention,
        )

        # Both 1x1 convolutions are optional; every code path checks them for `None` before use.
        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None

        self.use_slicing = False
        self.use_tiling = False

        # only relevant if vae tiling is enabled
        # NOTE(review): when `sample_size` is a list/tuple, `tile_sample_min_size` keeps the sequence as-is while
        # `tile_latent_min_size` is derived from its first element only - confirm tiling is not used with such
        # configs, since the tiling code compares and multiplies `tile_sample_min_size` as a number.
        self.tile_sample_min_size = self.config.sample_size
        sample_size = (
            self.config.sample_size[0]
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size
        )
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` to `value` on `module` if it is an `Encoder` or a `Decoder`."""
        if isinstance(module, (Encoder, Decoder)):
            module.gradient_checkpointing = value

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.enable_tiling(False)

    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name. A new dictionary is built on every access.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors. Note that the
                entries are popped from the passed dict while they are assigned.

        Raises:
            ValueError: If a dict is passed whose length differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.

        Raises:
            ValueError: If the current processors are neither all added-KV processors nor all cross-attention
                processors.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded images. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        # Inputs larger than one tile in either spatial dimension go through the tiled path.
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            # Encode one batch element at a time and stitch the results back together.
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        if self.quant_conv is not None:
            moments = self.quant_conv(h)
        else:
            moments = h

        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        """Decode `z` in a single pass, or delegate to `tiled_decode` when tiling applies."""
        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.tiled_decode(z, return_dict=return_dict)

        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)

        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(
        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        """
        Decode a batch of images.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
            generator: Accepted but not used by this method.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.

        """
        if self.use_slicing and z.shape[0] > 1:
            # Decode one batch element at a time and stitch the results back together.
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """
        Linearly cross-fade the last rows of `a` into the first rows of `b` along dimension 2.

        `b` is modified in place and returned. `blend_extent` is clamped to the size of dimension 2 of both tensors.
        """
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """
        Linearly cross-fade the last columns of `a` into the first columns of `b` along dimension 3.

        `b` is modified in place and returned. `blend_extent` is clamped to the size of dimension 3 of both tensors.
        """
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile is encoded independently of its neighbours. To avoid
        tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still see
        tile-sized changes in the output, but they should be much less noticeable.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                `tuple` is returned.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into overlapping tiles of side `tile_sample_min_size` and encode them separately.
        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                # Same `None` check as the non-tiled `encode` path.
                if self.quant_conv is not None:
                    tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        moments = torch.cat(result_rows, dim=2)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping tiles of side `tile_latent_min_size` and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[2], overlap_size):
            row = []
            for j in range(0, z.shape[3], overlap_size):
                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                # Same `None` check as the non-tiled `_decode` path.
                if self.post_quant_conv is not None:
                    tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        dec = torch.cat(result_rows, dim=2)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        r"""
        Encode `sample`, pick a latent from the posterior and decode it again.

        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior; otherwise its mode is used.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                Generator used when sampling from the posterior.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Raises:
            ValueError: If any current attention processor class has "Added" in its name.
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # Remember the current processors so `unfuse_qkv_projections` can restore them.
        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    # Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        Calling this before `fuse_qkv_projections` has ever been called is a no-op.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        # `original_attn_processors` only exists once `fuse_qkv_projections` has run; a plain attribute
        # access would raise `AttributeError` on a freshly constructed model.
        original_attn_processors = getattr(self, "original_attn_processors", None)
        if original_attn_processors is not None:
            self.set_attn_processor(original_attn_processors)
class CogVideoXCausalConv3d(nn.Module):
    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.

    The layer is stateful: every forward pass stores the last `time_kernel_size - 1` frames of its (temporally
    padded) input in `conv_cache`, and the next forward pass prepends them to its input. Consecutive calls are
    therefore treated as consecutive temporal chunks until `_clear_fake_context_parallel_cache` is called. When no
    cache is present, the first frame is repeated `time_kernel_size - 1` times instead.

    Args:
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of output channels.
        kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel, as a single int or as
            `(time, height, width)`.
        stride (int, optional): Stride of the convolution along the temporal dimension. Default is 1.
        dilation (int, optional): Dilation rate of the convolution along the temporal dimension. Default is 1.
        pad_mode (str, optional): Padding mode. Default is "constant". The value is stored on the module; `forward`
            always pads height and width with constant zeros.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: int = 1,
        dilation: int = 1,
        pad_mode: str = "constant",
    ):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        self.pad_mode = pad_mode
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.height_pad = height_pad
        self.width_pad = width_pad
        self.time_pad = time_pad
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        self.temporal_dim = 2
        self.time_kernel_size = time_kernel_size

        # Stride and dilation only apply along the temporal axis.
        stride = (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = CogVideoXSafeConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

        self.conv_cache = None

    def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Prepend `time_kernel_size - 1` frames of temporal context to `inputs`.

        The context is the cached tail of the previous call when available, otherwise the first frame of `inputs`
        repeated. With a temporal kernel size of 1 the input is returned unchanged.
        """
        dim = self.temporal_dim
        kernel_size = self.time_kernel_size
        if kernel_size == 1:
            return inputs

        # Move the temporal axis to the front so frames can be concatenated along dim 0.
        inputs = inputs.transpose(0, dim)

        if self.conv_cache is not None:
            inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
        else:
            inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)

        inputs = inputs.transpose(0, dim).contiguous()
        return inputs

    def _clear_fake_context_parallel_cache(self):
        """Drop the cached temporal context so the next forward pass starts a fresh sequence."""
        del self.conv_cache
        self.conv_cache = None

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply the causal convolution to `inputs` and refresh the temporal cache.

        Args:
            inputs (`torch.Tensor`): Input whose dimension 2 is the temporal axis.

        Returns:
            `torch.Tensor`: Output of the underlying convolution.
        """
        input_parallel = self.fake_context_parallel_forward(inputs)

        self._clear_fake_context_parallel_cache()
        # Only keep a cache when the temporal kernel actually needs context. With a temporal kernel size of 1 the
        # slice `[-k + 1:]` would be `[0:]`, i.e. the whole activation would be cloned to the CPU on every call
        # although `fake_context_parallel_forward` never reads the cache in that case.
        if self.time_kernel_size > 1:
            self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()

        # Temporal padding was handled above; pad only width and height here.
        padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
        input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)

        output_parallel = self.conv(input_parallel)
        output = output_parallel
        return output
eps (float, optional): Epsilon value for normalization layers. Default is 1e-6. non_linearity (str, optional): Activation function to use. Default is "swish". conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False. spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None. pad_mode (str, optional): Padding mode. Default is "first". """ def __init__( self, in_channels: int, out_channels: Optional[int] = None, dropout: float = 0.0, temb_channels: int = 512, groups: int = 32, eps: float = 1e-6, non_linearity: str = "swish", conv_shortcut: bool = False, spatial_norm_dim: Optional[int] = None, pad_mode: str = "first", ): super().__init__() out_channels = out_channels or in_channels self.in_channels = in_channels self.out_channels = out_channels self.nonlinearity = get_activation(non_linearity) self.use_conv_shortcut = conv_shortcut if spatial_norm_dim is None: self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) else: self.norm1 = CogVideoXSpatialNorm3D( f_channels=in_channels, zq_channels=spatial_norm_dim, groups=groups, ) self.norm2 = CogVideoXSpatialNorm3D( f_channels=out_channels, zq_channels=spatial_norm_dim, groups=groups, ) self.conv1 = CogVideoXCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) if temb_channels > 0: self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels) self.dropout = nn.Dropout(dropout) self.conv2 = CogVideoXCausalConv3d( in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = CogVideoXCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode ) else: self.conv_shortcut = CogVideoXSafeConv3d( in_channels=in_channels, 
class CogVideoXDownBlock3D(nn.Module):
    r"""
    Down stage of the CogVideoX model: a stack of `CogVideoXResnetBlock3D` layers, optionally followed by one
    `CogVideoXDownsample3D` layer.

    Args:
        in_channels (`int`): Number of input channels (consumed by the first resnet only).
        out_channels (`int`): Number of output channels.
        temb_channels (`int`): Number of time embedding channels.
        dropout (`float`, *optional*, defaults to 0.0): Dropout rate of the resnets.
        num_layers (`int`, *optional*, defaults to 1): Number of resnet layers.
        resnet_eps (`float`, *optional*, defaults to 1e-6): Epsilon for the resnet normalization layers.
        resnet_act_fn (`str`, *optional*, defaults to `"swish"`): Activation function of the resnets.
        resnet_groups (`int`, *optional*, defaults to 32): Number of groups for group normalization.
        add_downsample (`bool`, *optional*, defaults to `True`): Whether to append a downsampling layer.
        downsample_padding (`int`, *optional*, defaults to 0): Padding passed to the downsampling layer.
        compress_time (`bool`, *optional*, defaults to `False`): Forwarded to the downsampling layer.
        pad_mode (`str`, *optional*, defaults to `"first"`): Padding mode forwarded to the resnets.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_padding: int = 0,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        # Only the first resnet sees `in_channels`; every later one maps out_channels -> out_channels.
        self.resnets = nn.ModuleList(
            [
                CogVideoXResnetBlock3D(
                    in_channels=in_channels if layer_idx == 0 else out_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
                for layer_idx in range(num_layers)
            ]
        )

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    CogVideoXDownsample3D(
                        out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Apply every resnet (checkpointed while training if enabled), then the downsamplers if present."""

        def as_plain_function(module):
            # `checkpoint` receives a plain function that forwards its positional inputs to the module.
            def run(*inputs):
                return module(*inputs)

            return run

        for resnet in self.resnets:
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    as_plain_function(resnet), hidden_states, temb, zq
                )
            else:
                hidden_states = resnet(hidden_states, temb, zq)

        if self.downsamplers is None:
            return hidden_states

        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states)
        return hidden_states
class CogVideoXMidBlock3D(nn.Module):
    r"""
    Middle stage of the CogVideoX model: a stack of channel-preserving `CogVideoXResnetBlock3D` layers.

    Args:
        in_channels (`int`): Number of input (and output) channels.
        temb_channels (`int`): Number of time embedding channels.
        dropout (`float`, *optional*, defaults to 0.0): Dropout rate of the resnets.
        num_layers (`int`, *optional*, defaults to 1): Number of resnet layers.
        resnet_eps (`float`, *optional*, defaults to 1e-6): Epsilon for the resnet normalization layers.
        resnet_act_fn (`str`, *optional*, defaults to `"swish"`): Activation function of the resnets.
        resnet_groups (`int`, *optional*, defaults to 32): Number of groups for group normalization.
        spatial_norm_dim (`int`, *optional*): Forwarded to the resnets as their `spatial_norm_dim`.
        pad_mode (`str`, *optional*, defaults to `"first"`): Padding mode forwarded to the resnets.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: Optional[int] = None,
        pad_mode: str = "first",
    ):
        super().__init__()

        self.resnets = nn.ModuleList(
            [
                CogVideoXResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    spatial_norm_dim=spatial_norm_dim,
                    non_linearity=resnet_act_fn,
                    pad_mode=pad_mode,
                )
                for _ in range(num_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Apply every resnet in order, using activation checkpointing while training if it is enabled."""

        def as_plain_function(module):
            # `checkpoint` receives a plain function that forwards its positional inputs to the module.
            def run(*inputs):
                return module(*inputs)

            return run

        for resnet in self.resnets:
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    as_plain_function(resnet), hidden_states, temb, zq
                )
            else:
                hidden_states = resnet(hidden_states, temb, zq)

        return hidden_states
class CogVideoXUpBlock3D(nn.Module):
    r"""
    Up stage of the CogVideoX model: a stack of `CogVideoXResnetBlock3D` layers, optionally followed by one
    `CogVideoXUpsample3D` layer.

    Args:
        in_channels (`int`): Number of input channels (consumed by the first resnet only).
        out_channels (`int`): Number of output channels.
        temb_channels (`int`): Number of time embedding channels.
        dropout (`float`, *optional*, defaults to 0.0): Dropout rate of the resnets.
        num_layers (`int`, *optional*, defaults to 1): Number of resnet layers.
        resnet_eps (`float`, *optional*, defaults to 1e-6): Epsilon for the resnet normalization layers.
        resnet_act_fn (`str`, *optional*, defaults to `"swish"`): Activation function of the resnets.
        resnet_groups (`int`, *optional*, defaults to 32): Number of groups for group normalization.
        spatial_norm_dim (`int`, *optional*, defaults to 16): Forwarded to the resnets as their `spatial_norm_dim`.
        add_upsample (`bool`, *optional*, defaults to `True`): Whether to append an upsampling layer.
        upsample_padding (`int`, *optional*, defaults to 1): Padding passed to the upsampling layer.
        compress_time (`bool`, *optional*, defaults to `False`): Forwarded to the upsampling layer.
        pad_mode (`str`, *optional*, defaults to `"first"`): Padding mode forwarded to the resnets.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        spatial_norm_dim: int = 16,
        add_upsample: bool = True,
        upsample_padding: int = 1,
        compress_time: bool = False,
        pad_mode: str = "first",
    ):
        super().__init__()

        # Only the first resnet sees `in_channels`; every later one maps out_channels -> out_channels.
        self.resnets = nn.ModuleList(
            [
                CogVideoXResnetBlock3D(
                    in_channels=in_channels if layer_idx == 0 else out_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=resnet_groups,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    spatial_norm_dim=spatial_norm_dim,
                    pad_mode=pad_mode,
                )
                for layer_idx in range(num_layers)
            ]
        )

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    CogVideoXUpsample3D(
                        out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
                    )
                ]
            )
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        zq: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Apply every resnet (checkpointed while training if enabled), then the upsamplers if present."""

        def as_plain_function(module):
            # `checkpoint` receives a plain function that forwards its positional inputs to the module.
            def run(*inputs):
                return module(*inputs)

            return run

        for resnet in self.resnets:
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    as_plain_function(resnet), hidden_states, temb, zq
                )
            else:
                hidden_states = resnet(hidden_states, temb, zq)

        if self.upsamplers is None:
            return hidden_states

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)
        return hidden_states
class CogVideoXEncoder3D(nn.Module):
    r"""
    The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 16):
            The number of latent channels. The final convolution emits `2 * out_channels` channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXDownBlock3D"` entries):
            The types of down blocks to use. Only `"CogVideoXDownBlock3D"` is accepted.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 3):
            The number of resnet layers per down block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function used inside the resnets.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon for the resnet normalization layers.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate of the resnets.
        pad_mode (`str`, *optional*, defaults to `"first"`):
            Padding mode forwarded to the causal convolutions, the down blocks and the mid block.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            The first `int(log2(temporal_compression_ratio))` down blocks are built with `compress_time=True`.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        down_block_types: Tuple[str, ...] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # log2 of temporal_compress_times
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
        self.down_blocks = nn.ModuleList([])

        # down blocks
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            if down_block_type != "CogVideoXDownBlock3D":
                raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")

            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            compress_time = i < temporal_compress_level

            down_block = CogVideoXDownBlock3D(
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=0,
                dropout=dropout,
                num_layers=layers_per_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                add_downsample=not is_final_block,
                compress_time=compress_time,
                # Forward the padding mode so the down blocks agree with `conv_in`, the mid block and `conv_out`
                # (previously the block's own default was silently used).
                pad_mode=pad_mode,
            )
            self.down_blocks.append(down_block)

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=block_out_channels[-1],
            temb_channels=0,
            dropout=dropout,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            pad_mode=pad_mode,
        )

        # NOTE: the output norm uses a fixed eps of 1e-6 rather than `norm_eps`; kept unchanged.
        self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
        self.conv_act = nn.SiLU()
        # Twice the latent channel count is emitted by the final convolution.
        self.conv_out = CogVideoXCausalConv3d(
            block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""
        The forward method of the `CogVideoXEncoder3D` class.

        Args:
            sample (`torch.Tensor`): Input tensor fed to `conv_in`.
            temb (`torch.Tensor`, *optional*): Time embedding handed to the down and mid blocks.

        Returns:
            `torch.Tensor`: Output of `conv_out`.
        """
        hidden_states = self.conv_in(sample)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Down
            for down_block in self.down_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block), hidden_states, temb, None
                )

            # 2. Mid
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), hidden_states, temb, None
            )
        else:
            # 1. Down (the encoder never supplies a conditioning tensor, hence `zq=None`)
            for down_block in self.down_blocks:
                hidden_states = down_block(hidden_states, temb, None)

            # 2. Mid
            hidden_states = self.mid_block(hidden_states, temb, None)

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states
class CogVideoXDecoder3D(nn.Module):
    r"""
    The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            The number of latent (input) channels. Also used as the `spatial_norm_dim` of every spatial norm,
            because the latent itself is the conditioning tensor.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXUpBlock3D"` entries):
            The types of up blocks to use. Only `"CogVideoXUpBlock3D"` is accepted.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block; traversed in reverse order by the decoder.
        layers_per_block (`int`, *optional*, defaults to 3):
            Each up block is built with `layers_per_block + 1` resnet layers.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function used inside the resnets.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon for the resnet normalization layers.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate of the resnets in the mid block and the up blocks.
        pad_mode (`str`, *optional*, defaults to `"first"`):
            Padding mode forwarded to the causal convolutions and all blocks.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            The first `int(log2(temporal_compression_ratio))` up blocks are built with `compress_time=True`.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        reversed_block_out_channels = list(reversed(block_out_channels))

        self.conv_in = CogVideoXCausalConv3d(
            in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
        )

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=reversed_block_out_channels[0],
            temb_channels=0,
            # Previously omitted, so the mid block ignored `dropout` while the up blocks (and the encoder's mid
            # block) honoured it.
            dropout=dropout,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            spatial_norm_dim=in_channels,
            pad_mode=pad_mode,
        )

        # up blocks
        self.up_blocks = nn.ModuleList([])

        output_channel = reversed_block_out_channels[0]
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        for i, up_block_type in enumerate(up_block_types):
            if up_block_type != "CogVideoXUpBlock3D":
                raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            compress_time = i < temporal_compress_level

            up_block = CogVideoXUpBlock3D(
                in_channels=prev_output_channel,
                out_channels=output_channel,
                temb_channels=0,
                dropout=dropout,
                num_layers=layers_per_block + 1,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                spatial_norm_dim=in_channels,
                add_upsample=not is_final_block,
                compress_time=compress_time,
                pad_mode=pad_mode,
            )
            self.up_blocks.append(up_block)

        self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""
        The forward method of the `CogVideoXDecoder3D` class.

        Args:
            sample (`torch.Tensor`): Latent tensor. It is fed to `conv_in` and also passed, unchanged, as the
                conditioning tensor `zq` of every block and of the output norm.
            temb (`torch.Tensor`, *optional*): Time embedding handed to the mid and up blocks.

        Returns:
            `torch.Tensor`: Output of `conv_out`.
        """
        hidden_states = self.conv_in(sample)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Mid
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), hidden_states, temb, sample
            )

            # 2. Up
            for up_block in self.up_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block), hidden_states, temb, sample
                )
        else:
            # 1. Mid
            hidden_states = self.mid_block(hidden_states, temb, sample)

            # 2. Up
            for up_block in self.up_blocks:
                hidden_states = up_block(hidden_states, temb, sample)

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states, sample)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
    [CogVideoX](https://github.com/THUDM/CogVideo).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to four `"CogVideoXDownBlock3D"` entries):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to four `"CogVideoXUpBlock3D"` entries):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(128, 256, 256, 512)`):
            Tuple of block output channels.
        latent_channels (`int`, *optional*, defaults to 16): Number of channels in the latent space.
        layers_per_block (`int`, *optional*, defaults to 3): Number of resnet layers per block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_eps (`float`, *optional*, defaults to 1e-6): Epsilon for the normalization layers.
        norm_num_groups (`int`, *optional*, defaults to 32): Number of groups for normalization.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            Forwarded to the encoder and decoder, where it decides how many blocks compress/expand time.
        sample_size (`int`, *optional*, defaults to 256): Sample input size.
        scaling_factor (`float`, *optional*, defaults to 1.15258426):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        shift_factor (`float`, *optional*), latents_mean (`Tuple[float]`, *optional*),
        latents_std (`Tuple[float]`, *optional*):
            Accepted by the constructor but not read by any code in this class.
        force_upcast (`bool`, *optional*, default to `True`):
            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
            can be fine-tuned / trained to a lower range without losing too much precision in which case
            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
        use_quant_conv (`bool`, *optional*, defaults to `False`):
            Whether to apply a 1x1x1 convolution to the encoder output before building the posterior.
        use_post_quant_conv (`bool`, *optional*, defaults to `False`):
            Whether to apply a 1x1x1 convolution to the latents before decoding.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["CogVideoXResnetBlock3D"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        up_block_types: Tuple[str] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int] = (128, 256, 256, 512),
        latent_channels: int = 16,
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        temporal_compression_ratio: float = 4,
        sample_size: int = 256,
        scaling_factor: float = 1.15258426,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = False,
        use_post_quant_conv: bool = False,
    ):
        super().__init__()

        self.encoder = CogVideoXEncoder3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_eps=norm_eps,
            norm_num_groups=norm_num_groups,
            temporal_compression_ratio=temporal_compression_ratio,
        )
        self.decoder = CogVideoXDecoder3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_eps=norm_eps,
            norm_num_groups=norm_num_groups,
            temporal_compression_ratio=temporal_compression_ratio,
        )

        # Both convolutions act on latent-space tensors: `quant_conv` on the encoder output
        # (2 * latent_channels) and `post_quant_conv` on the decoder input (latent_channels). They were sized
        # with `out_channels` before, which only worked when `latent_channels == out_channels`.
        self.quant_conv = (
            CogVideoXSafeConv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
        )
        self.post_quant_conv = (
            CogVideoXSafeConv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
        )

        self.use_slicing = False
        self.use_tiling = False

        self.tile_sample_min_size = self.config.sample_size
        # `sample_size` may be stored as a list/tuple in the config; only its first entry is used here.
        sample_size = (
            self.config.sample_size[0]
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size
        )
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle `gradient_checkpointing` on encoder/decoder modules; other modules are left untouched."""
        if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
            module.gradient_checkpointing = value

    def clear_fake_context_parallel_cache(self):
        """Call `_clear_fake_context_parallel_cache()` on every `CogVideoXCausalConv3d` submodule."""
        for name, module in self.named_modules():
            if isinstance(module, CogVideoXCausalConv3d):
                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
                module._clear_fake_context_parallel_cache()

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded images. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        h = self.encoder(x)
        if self.quant_conv is not None:
            h = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(h)
        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Decode a batch of images.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)
        dec = self.decoder(z)
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        r"""
        Encode `sample`, take a latent from the posterior and decode it again.

        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior (otherwise its mode is used).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                Random generator used when sampling from the posterior.

        Returns:
            [`DecoderOutput`] when `return_dict` is True, otherwise a one-element tuple holding the decoded tensor.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        # Unwrap the tensor so the tuple form carries the tensor itself rather than a `DecoderOutput`.
        dec = self.decode(z).sample
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
class TemporalDecoder(nn.Module):
    r"""
    Decoder made of a 2D input convolution, a `MidBlockTemporalDecoder`, a chain of `UpBlockTemporalDecoder`
    stages, a 2D output head and a final `Conv3d` with kernel `(3, 1, 1)` applied across the frame axis.

    Args:
        in_channels (`int`, *optional*, defaults to 4): Number of input channels.
        out_channels (`int`, *optional*, defaults to 3): Number of output channels.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(128, 256, 512, 512)`):
            Channel counts per stage; the decoder walks them in reverse order.
        layers_per_block (`int`, *optional*, defaults to 2):
            Layers of the mid block; every up block gets `layers_per_block + 1` layers.
    """

    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 3,
        block_out_channels: Tuple[int] = (128, 256, 512, 512),
        layers_per_block: int = 2,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        widest = block_out_channels[-1]
        self.conv_in = nn.Conv2d(in_channels, widest, kernel_size=3, stride=1, padding=1)
        self.mid_block = MidBlockTemporalDecoder(
            num_layers=self.layers_per_block,
            in_channels=widest,
            out_channels=widest,
            attention_head_dim=widest,
        )

        # up
        self.up_blocks = nn.ModuleList([])
        stage_channels = list(reversed(block_out_channels))
        final_idx = len(block_out_channels) - 1
        for idx, stage_out in enumerate(stage_channels):
            # The first stage keeps the channel count; later stages start from the previous stage's output.
            stage_in = stage_channels[idx - 1] if idx > 0 else stage_channels[0]
            self.up_blocks.append(
                UpBlockTemporalDecoder(
                    num_layers=self.layers_per_block + 1,
                    in_channels=stage_in,
                    out_channels=stage_out,
                    add_upsample=idx != final_idx,
                )
            )

        narrowest = block_out_channels[0]
        self.conv_norm_out = nn.GroupNorm(num_channels=narrowest, num_groups=32, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = torch.nn.Conv2d(
            in_channels=narrowest,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )

        # Half-kernel padding keeps the frame/height/width extents unchanged.
        time_kernel = (3, 1, 1)
        time_padding = [int(k // 2) for k in time_kernel]
        self.time_conv_out = torch.nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=time_kernel,
            padding=time_padding,
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        image_only_indicator: torch.Tensor,
        num_frames: int = 1,
    ) -> torch.Tensor:
        r"""
        The forward method of the `TemporalDecoder` class.

        Args:
            sample (`torch.Tensor`): 4D input whose first dimension is `batch * num_frames`.
            image_only_indicator (`torch.Tensor`): Passed through to the mid and up blocks.
            num_frames (`int`, *optional*, defaults to 1): Frames per video, used to regroup the batch dimension.

        Returns:
            `torch.Tensor`: 4D tensor with the same first dimension as the input.
        """
        sample = self.conv_in(sample)
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        if self.training and self.gradient_checkpointing:

            def as_plain_function(module):
                def run(*inputs):
                    return module(*inputs)

                return run

            # `use_reentrant=False` is only passed on torch >= 1.11.0.
            extra_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            # middle
            sample = torch.utils.checkpoint.checkpoint(
                as_plain_function(self.mid_block),
                sample,
                image_only_indicator,
                **extra_kwargs,
            )
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = torch.utils.checkpoint.checkpoint(
                    as_plain_function(up_block),
                    sample,
                    image_only_indicator,
                    **extra_kwargs,
                )
        else:
            # middle
            sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, image_only_indicator=image_only_indicator)

        # post-process
        sample = self.conv_out(self.conv_act(self.conv_norm_out(sample)))

        batch_frames, channels, height, width = sample.shape
        batch_size = batch_frames // num_frames
        # (batch * frames, C, H, W) -> (batch, C, frames, H, W) for the temporal convolution, then back again.
        video = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
        video = self.time_conv_out(video)
        return video.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). Parameters: in_channels (int, *optional*, defaults to 3): Number of channels in the input image. out_channels (int, *optional*, defaults to 3): Number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): Tuple of downsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): Tuple of block output channels. layers_per_block: (`int`, *optional*, defaults to 1): Number of layers per block. latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. sample_size (`int`, *optional*, defaults to `32`): Sample input size. scaling_factor (`float`, *optional*, defaults to 0.18215): The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. force_upcast (`bool`, *optional*, default to `True`): If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. 
VAE can be fine-tuned / trained to a lower range without loosing too much precision in which case `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, in_channels: int = 3, out_channels: int = 3, down_block_types: Tuple[str] = ("DownEncoderBlock2D",), block_out_channels: Tuple[int] = (64,), layers_per_block: int = 1, latent_channels: int = 4, sample_size: int = 32, scaling_factor: float = 0.18215, force_upcast: float = True, ): super().__init__() # pass init params to Encoder self.encoder = Encoder( in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, double_z=True, ) # pass init params to Decoder self.decoder = TemporalDecoder( in_channels=latent_channels, out_channels=out_channels, block_out_channels=block_out_channels, layers_per_block=layers_per_block, ) self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) sample_size = ( self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size ) self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (Encoder, TemporalDecoder)): module.gradient_checkpointing = value @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. 
""" # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnProcessor() else: raise ValueError( f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) self.set_attn_processor(processor) @apply_forward_hook def encode( self, x: torch.Tensor, return_dict: bool = True ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: """ Encode a batch of images into latents. Args: x (`torch.Tensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. Returns: The latent representations of the encoded images. If `return_dict` is True, a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ h = self.encoder(x) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) if not return_dict: return (posterior,) return AutoencoderKLOutput(latent_dist=posterior) @apply_forward_hook def decode( self, z: torch.Tensor, num_frames: int, return_dict: bool = True, ) -> Union[DecoderOutput, torch.Tensor]: """ Decode a batch of images. Args: z (`torch.Tensor`): Input batch of latent vectors. 
def forward(
    self,
    sample: torch.Tensor,
    sample_posterior: bool = False,
    return_dict: bool = True,
    generator: Optional[torch.Generator] = None,
    num_frames: int = 1,
) -> Union[DecoderOutput, torch.Tensor]:
    r"""
    Encode `sample`, pick a latent from the posterior and decode it again.

    Args:
        sample (`torch.Tensor`): Input sample.
        sample_posterior (`bool`, *optional*, defaults to `False`):
            If `True` the latent is drawn with `posterior.sample(generator=generator)`, otherwise
            `posterior.mode()` is used.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        generator (`torch.Generator`, *optional*):
            Only forwarded to `posterior.sample` when `sample_posterior` is `True`.
        num_frames (`int`, *optional*, defaults to 1):
            Forwarded unchanged to `self.decode`.

    Returns:
        A [`DecoderOutput`] wrapping the decoded tensor, or a one-element tuple when `return_dict` is `False`.
    """
    latent_dist = self.encode(sample).latent_dist
    latents = latent_dist.sample(generator=generator) if sample_posterior else latent_dist.mode()
    reconstruction = self.decode(latents, num_frames=num_frames).sample

    if return_dict:
        return DecoderOutput(sample=reconstruction)
    return (reconstruction,)
class Snake1d(nn.Module):
    """
    A 1-dimensional Snake activation: `x + sin(alpha * x) ** 2 / (beta + 1e-9)`.

    `alpha` and `beta` are learnable tensors of shape `(1, hidden_dim, 1)` initialised to zero. With
    `logscale=True` they are exponentiated before use.
    """

    def __init__(self, hidden_dim, logscale=True):
        super().__init__()
        param_shape = (1, hidden_dim, 1)
        self.alpha = nn.Parameter(torch.zeros(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))
        for param in (self.alpha, self.beta):
            param.requires_grad = True
        self.logscale = logscale

    def forward(self, hidden_states):
        """Apply the activation; the input is flattened to three dimensions and restored afterwards."""
        original_shape = hidden_states.shape

        if self.logscale:
            alpha, beta = torch.exp(self.alpha), torch.exp(self.beta)
        else:
            alpha, beta = self.alpha, self.beta

        # Collapse everything after the channel dimension so the (1, C, 1) parameters broadcast.
        flat = hidden_states.reshape(original_shape[0], original_shape[1], -1)
        inverse_beta = (beta + 1e-9).reciprocal()  # epsilon guards against division by zero
        activated = flat + inverse_beta * torch.sin(alpha * flat).pow(2)
        return activated.reshape(original_shape)
class OobleckEncoderBlock(nn.Module):
    """
    Encoder block used in Oobleck encoder.

    Three residual units (dilations 1, 3 and 9) at `input_dim` channels are followed by a Snake
    activation and a strided, weight-normalized convolution that maps to `output_dim` channels.
    """

    def __init__(self, input_dim, output_dim, stride: int = 1):
        super().__init__()

        # Registers res_unit1, res_unit2 and res_unit3 in that order.
        for position, dilation in enumerate((1, 3, 9), start=1):
            setattr(self, f"res_unit{position}", OobleckResidualUnit(input_dim, dilation=dilation))

        self.snake1 = Snake1d(input_dim)

        strided_conv = nn.Conv1d(
            input_dim,
            output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
        )
        self.conv1 = weight_norm(strided_conv)

    def forward(self, hidden_state):
        """Run the residual units in order, then the activation and the strided convolution."""
        for res_unit in (self.res_unit1, self.res_unit2, self.res_unit3):
            hidden_state = res_unit(hidden_state)
        return self.conv1(self.snake1(hidden_state))
class OobleckDiagonalGaussianDistribution(object):
    """
    Diagonal Gaussian parameterised by a tensor whose dim 1 stacks the mean and a raw scale.

    `parameters` is split in two equal chunks along dim 1. The first chunk is the mean; the second is
    passed through `softplus` and offset by `1e-4` to obtain a strictly positive standard deviation.

    Args:
        parameters (`torch.Tensor`): Stacked mean / raw-scale tensor.
        deterministic (`bool`, defaults to `False`): When `True`, `kl` returns a zero tensor. It does
            not change what `sample` returns.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.scale = parameters.chunk(2, dim=1)
        self.std = nn.functional.softplus(self.scale) + 1e-4
        self.var = self.std * self.std
        self.logvar = torch.log(self.var)
        self.deterministic = deterministic

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Draw `mean + std * noise`, with the noise created on the device/dtype of `parameters`."""
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        return self.mean + self.std * sample

    def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
        """
        KL divergence from this distribution to `other` (or to a standard normal when `other` is `None`).

        The per-element terms are summed over dim 1 and averaged over the remaining dimensions. Both
        branches use the same convention and omit the customary factor 1/2.

        Returns:
            A scalar tensor, or `torch.Tensor([0.0])` when the distribution is deterministic.
        """
        if self.deterministic:
            return torch.Tensor([0.0])

        if other is None:
            return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()

        normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
        var_ratio = self.var / other.var
        # log(other.var / self.var). With other = N(0, 1) this reduces to `-self.logvar`, matching the
        # closed form used above; the reversed difference would flip the sign of this term.
        logvar_diff = other.logvar - self.logvar

        kl = normalized_diff + var_ratio + logvar_diff - 1
        return kl.sum(1).mean()

    def mode(self) -> torch.Tensor:
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean
""" sample: torch.Tensor class OobleckEncoder(nn.Module): """Oobleck Encoder""" def __init__(self, encoder_hidden_size, audio_channels, downsampling_ratios, channel_multiples): super().__init__() strides = downsampling_ratios channel_multiples = [1] + channel_multiples # Create first convolution self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3)) self.block = [] # Create EncoderBlocks that double channels as they downsample by `stride` for stride_index, stride in enumerate(strides): self.block += [ OobleckEncoderBlock( input_dim=encoder_hidden_size * channel_multiples[stride_index], output_dim=encoder_hidden_size * channel_multiples[stride_index + 1], stride=stride, ) ] self.block = nn.ModuleList(self.block) d_model = encoder_hidden_size * channel_multiples[-1] self.snake1 = Snake1d(d_model) self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1)) def forward(self, hidden_state): hidden_state = self.conv1(hidden_state) for module in self.block: hidden_state = module(hidden_state) hidden_state = self.snake1(hidden_state) hidden_state = self.conv2(hidden_state) return hidden_state class OobleckDecoder(nn.Module): """Oobleck Decoder""" def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples): super().__init__() strides = upsampling_ratios channel_multiples = [1] + channel_multiples # Add first conv layer self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3)) # Add upsampling + MRF blocks block = [] for stride_index, stride in enumerate(strides): block += [ OobleckDecoderBlock( input_dim=channels * channel_multiples[len(strides) - stride_index], output_dim=channels * channel_multiples[len(strides) - stride_index - 1], stride=stride, ) ] self.block = nn.ModuleList(block) output_dim = channels self.snake1 = Snake1d(output_dim) self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, 
class AutoencoderOobleck(ModelMixin, ConfigMixin):
    r"""
    An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
    introduced in Stable Audio.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        encoder_hidden_size (`int`, *optional*, defaults to 128):
            Intermediate representation dimension for the encoder.
        downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
            Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
        channel_multiples (`List[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
            Multiples used to determine the hidden sizes of the hidden layers.
        decoder_channels (`int`, *optional*, defaults to 128):
            Intermediate representation dimension for the decoder.
        decoder_input_channels (`int`, *optional*, defaults to 64):
            Input dimension for the decoder. Corresponds to the latent dimension.
        audio_channels (`int`, *optional*, defaults to 2):
            Number of channels in the audio data. Either 1 for mono or 2 for stereo.
        sampling_rate (`int`, *optional*, defaults to 44100):
            The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        encoder_hidden_size=128,
        downsampling_ratios=[2, 4, 4, 8, 8],
        channel_multiples=[1, 2, 4, 8, 16],
        decoder_channels=128,
        decoder_input_channels=64,
        audio_channels=2,
        sampling_rate=44100,
    ):
        super().__init__()

        # The decoder mirrors the encoder, so it upsamples with the ratios in reverse order.
        reversed_ratios = downsampling_ratios[::-1]

        self.encoder_hidden_size = encoder_hidden_size
        self.downsampling_ratios = downsampling_ratios
        self.decoder_channels = decoder_channels
        self.upsampling_ratios = reversed_ratios
        # Product of all downsampling ratios.
        self.hop_length = int(np.prod(downsampling_ratios))
        self.sampling_rate = sampling_rate

        self.encoder = OobleckEncoder(
            encoder_hidden_size=encoder_hidden_size,
            audio_channels=audio_channels,
            downsampling_ratios=downsampling_ratios,
            channel_multiples=channel_multiples,
        )
        self.decoder = OobleckDecoder(
            channels=decoder_channels,
            input_channels=decoder_input_channels,
            audio_channels=audio_channels,
            upsampling_ratios=self.upsampling_ratios,
            channel_multiples=channel_multiples,
        )

        self.use_slicing = False

    def enable_slicing(self):
        r"""
        Enable sliced processing. While enabled, `encode` and `decode` split a batch with more than one element
        into single-element slices, process them one at a time and concatenate the results.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced processing. After this call `encode` and `decode` process the whole batch in one step.
        """
        self.use_slicing = False

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]:
        """
        Encode a batch of inputs into a latent distribution.

        Args:
            x (`torch.Tensor`): Input batch.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`AutoencoderOobleckOutput`] instead of a plain tuple.

        Returns:
            An [`AutoencoderOobleckOutput`] holding the `OobleckDiagonalGaussianDistribution` built from the
            encoder output, or a one-element tuple with that distribution when `return_dict` is `False`.
        """
        slice_batch = self.use_slicing and x.shape[0] > 1
        if slice_batch:
            h = torch.cat([self.encoder(single) for single in x.split(1)])
        else:
            h = self.encoder(x)

        posterior = OobleckDiagonalGaussianDistribution(h)

        if return_dict:
            return AutoencoderOobleckOutput(latent_dist=posterior)
        return (posterior,)

    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]:
        """Run the decoder on `z` without any slicing."""
        decoded = self.decoder(z)

        if return_dict:
            return OobleckDecoderOutput(sample=decoded)
        return (decoded,)

    @apply_forward_hook
    def decode(
        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
    ) -> Union[OobleckDecoderOutput, torch.FloatTensor]:
        """
        Decode a batch of latent vectors.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`OobleckDecoderOutput`] instead of a plain tuple.
            generator: Accepted but not used by this method.

        Returns:
            An [`OobleckDecoderOutput`], or a one-element tuple with the decoded tensor when `return_dict` is
            `False`.
        """
        slice_batch = self.use_slicing and z.shape[0] > 1
        if slice_batch:
            decoded = torch.cat([self._decode(single).sample for single in z.split(1)])
        else:
            decoded = self._decode(z).sample

        if return_dict:
            return OobleckDecoderOutput(sample=decoded)
        return (decoded,)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[OobleckDecoderOutput, torch.Tensor]:
        r"""
        Encode `sample`, pick a latent from the posterior and decode it again.

        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                If `True` the latent is drawn with `posterior.sample(generator=generator)`, otherwise
                `posterior.mode()` is used.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`OobleckDecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                Only forwarded to `posterior.sample`.
        """
        posterior = self.encode(sample).latent_dist
        latents = posterior.sample(generator=generator) if sample_posterior else posterior.mode()
        reconstruction = self.decode(latents).sample

        if return_dict:
            return OobleckDecoderOutput(sample=reconstruction)
        return (reconstruction,)
Parameters: in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`): Tuple of integers representing the number of output channels for each encoder block. The length of the tuple should be equal to the number of encoder blocks. decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`): Tuple of integers representing the number of output channels for each decoder block. The length of the tuple should be equal to the number of decoder blocks. act_fn (`str`, *optional*, defaults to `"relu"`): Activation function to be used throughout the model. latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent representation. The latent space acts as a compressed representation of the input image. upsampling_scaling_factor (`int`, *optional*, defaults to 2): Scaling factor for upsampling in the decoder. It determines the size of the output image during the upsampling process. num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`): Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The length of the tuple should be equal to the number of stages in the encoder. Each stage has a different number of encoder blocks. num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`): Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The length of the tuple should be equal to the number of stages in the decoder. Each stage has a different number of decoder blocks. latent_magnitude (`float`, *optional*, defaults to 3.0): Magnitude of the latent representation. This parameter scales the latent representation values to control the extent of information preservation. 
def __init__(
    self,
    in_channels: int = 3,
    out_channels: int = 3,
    encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
    decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
    act_fn: str = "relu",
    upsample_fn: str = "nearest",
    latent_channels: int = 4,
    upsampling_scaling_factor: int = 2,
    num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
    num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
    latent_magnitude: int = 3,
    latent_shift: float = 0.5,
    force_upcast: bool = False,
    scaling_factor: float = 1.0,
    shift_factor: float = 0.0,
):
    """
    Validate the block configuration, build `EncoderTiny` / `DecoderTiny` and set the tiling defaults.

    Raises:
        ValueError: If a `*_block_out_channels` tuple and the matching `num_*_blocks` tuple differ in length.
    """
    super().__init__()

    # Each per-stage channel tuple must line up with its per-stage block-count tuple.
    length_checks = (
        (
            encoder_block_out_channels,
            num_encoder_blocks,
            "`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.",
        ),
        (
            decoder_block_out_channels,
            num_decoder_blocks,
            "`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.",
        ),
    )
    for channels, block_counts, message in length_checks:
        if len(channels) != len(block_counts):
            raise ValueError(message)

    encoder_kwargs = {
        "in_channels": in_channels,
        "out_channels": latent_channels,
        "num_blocks": num_encoder_blocks,
        "block_out_channels": encoder_block_out_channels,
        "act_fn": act_fn,
    }
    self.encoder = EncoderTiny(**encoder_kwargs)

    decoder_kwargs = {
        "in_channels": latent_channels,
        "out_channels": out_channels,
        "num_blocks": num_decoder_blocks,
        "block_out_channels": decoder_block_out_channels,
        "upsampling_scaling_factor": upsampling_scaling_factor,
        "act_fn": act_fn,
        "upsample_fn": upsample_fn,
    }
    self.decoder = DecoderTiny(**decoder_kwargs)

    self.latent_magnitude = latent_magnitude
    self.latent_shift = latent_shift
    self.scaling_factor = scaling_factor

    self.use_slicing = False
    self.use_tiling = False

    # only relevant if vae tiling is enabled
    # NOTE(review): the scale factor is derived from `out_channels` (8 for the default of 3), not from
    # the number of encoder stages -- confirm this is intended for models with other channel counts.
    self.spatial_scale_factor = 2**out_channels
    self.tile_overlap_factor = 0.125
    self.tile_sample_min_size = 512
    self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor

    self.register_to_config(block_out_channels=decoder_block_out_channels)
    # `force_upcast=False` is registered regardless of the value passed to this constructor.
    self.register_to_config(force_upcast=False)
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.use_slicing = True def disable_slicing(self) -> None: r""" Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing decoding in one step. """ self.use_slicing = False def enable_tiling(self, use_tiling: bool = True) -> None: r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images. """ self.use_tiling = use_tiling def disable_tiling(self) -> None: r""" Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing decoding in one step. """ self.enable_tiling(False) def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. Args: x (`torch.Tensor`): Input batch of images. Returns: `torch.Tensor`: Encoded batch of images. 
""" # scale of encoder output relative to input sf = self.spatial_scale_factor tile_size = self.tile_sample_min_size # number of pixels to blend and to traverse between tile blend_size = int(tile_size * self.tile_overlap_factor) traverse_size = tile_size - blend_size # tiles index (up/left) ti = range(0, x.shape[-2], traverse_size) tj = range(0, x.shape[-1], traverse_size) # mask for blending blend_masks = torch.stack( torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij") ) blend_masks = blend_masks.clamp(0, 1).to(x.device) # output array out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device) for i in ti: for j in tj: tile_in = x[..., i : i + tile_size, j : j + tile_size] # tile result tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf] tile = self.encoder(tile_in) h, w = tile.shape[-2], tile.shape[-1] # blend tile result into output blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0] blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1] blend_mask = blend_mask_i * blend_mask_j tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w] tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out) return out def _tiled_decode(self, x: torch.Tensor) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. Args: x (`torch.Tensor`): Input batch of images. Returns: `torch.Tensor`: Encoded batch of images. 
""" # scale of decoder output relative to input sf = self.spatial_scale_factor tile_size = self.tile_latent_min_size # number of pixels to blend and to traverse between tiles blend_size = int(tile_size * self.tile_overlap_factor) traverse_size = tile_size - blend_size # tiles index (up/left) ti = range(0, x.shape[-2], traverse_size) tj = range(0, x.shape[-1], traverse_size) # mask for blending blend_masks = torch.stack( torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij") ) blend_masks = blend_masks.clamp(0, 1).to(x.device) # output array out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device) for i in ti: for j in tj: tile_in = x[..., i : i + tile_size, j : j + tile_size] # tile result tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf] tile = self.decoder(tile_in) h, w = tile.shape[-2], tile.shape[-1] # blend tile result into output blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0] blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1] blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w] tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out) return out @apply_forward_hook def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]: if self.use_slicing and x.shape[0] > 1: output = [ self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1) ] output = torch.cat(output) else: output = self._tiled_encode(x) if self.use_tiling else self.encoder(x) if not return_dict: return (output,) return AutoencoderTinyOutput(latents=output) @apply_forward_hook def decode( self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True ) -> Union[DecoderOutput, Tuple[torch.Tensor]]: if self.use_slicing and x.shape[0] > 1: output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for 
x_slice in x.split(1)] output = torch.cat(output) else: output = self._tiled_decode(x) if self.use_tiling else self.decoder(x) if not return_dict: return (output,) return DecoderOutput(sample=output) def forward( self, sample: torch.Tensor, return_dict: bool = True, ) -> Union[DecoderOutput, Tuple[torch.Tensor]]: r""" Args: sample (`torch.Tensor`): Input sample. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`DecoderOutput`] instead of a plain tuple. """ enc = self.encode(sample).latents # scale latents to be in [0, 1], then quantize latents to a byte tensor, # as if we were storing the latents in an RGBA uint8 image. scaled_enc = self.scale_latents(enc).mul_(255).round_().byte() # unquantize latents back into [0, 1], then unscale latents back to their original range, # as if we were loading the latents from an RGBA uint8 image. unscaled_enc = self.unscale_latents(scaled_enc / 255.0) dec = self.decode(unscaled_enc) if not return_dict: return (dec,) return DecoderOutput(sample=dec) ================================================ FILE: src/diffusers/models/autoencoders/consistency_decoder_vae.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
@register_to_config
def __init__(
    self,
    scaling_factor: float = 0.18215,
    latent_channels: int = 4,
    sample_size: int = 32,
    encoder_act_fn: str = "silu",
    encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
    encoder_double_z: bool = True,
    encoder_down_block_types: Tuple[str, ...] = (
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
    ),
    encoder_in_channels: int = 3,
    encoder_layers_per_block: int = 2,
    encoder_norm_num_groups: int = 32,
    encoder_out_channels: int = 4,
    decoder_add_attention: bool = False,
    decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
    decoder_down_block_types: Tuple[str, ...] = (
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
    ),
    decoder_downsample_padding: int = 1,
    decoder_in_channels: int = 7,
    decoder_layers_per_block: int = 3,
    decoder_norm_eps: float = 1e-05,
    decoder_norm_num_groups: int = 32,
    decoder_num_train_timesteps: int = 1024,
    decoder_out_channels: int = 6,
    decoder_resnet_time_scale_shift: str = "scale_shift",
    decoder_time_embedding_type: str = "learned",
    decoder_up_block_types: Tuple[str, ...] = (
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
    ),
):
    r"""
    Build the encoder, the UNet used as consistency decoder, and its scheduler.

    Args:
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            Factor the latents are multiplied by in `decode` before being normalized with `means` / `stds`.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of latent channels; `quant_conv` maps `2 * latent_channels` to `2 * latent_channels`.
        sample_size (`int`, *optional*, defaults to 32):
            Sample size used to derive the tile sizes for tiled encoding.
        encoder_* :
            Forwarded to [`Encoder`] with the `encoder_` prefix removed.
        decoder_* :
            Forwarded to [`UNet2DModel`] with the `decoder_` prefix removed.
    """
    super().__init__()
    self.encoder = Encoder(
        act_fn=encoder_act_fn,
        block_out_channels=encoder_block_out_channels,
        double_z=encoder_double_z,
        down_block_types=encoder_down_block_types,
        in_channels=encoder_in_channels,
        layers_per_block=encoder_layers_per_block,
        norm_num_groups=encoder_norm_num_groups,
        out_channels=encoder_out_channels,
    )

    # `decode` feeds this UNet a 3-channel noisy image concatenated with the upsampled latents
    # and keeps only the first 3 output channels.
    self.decoder_unet = UNet2DModel(
        add_attention=decoder_add_attention,
        block_out_channels=decoder_block_out_channels,
        down_block_types=decoder_down_block_types,
        downsample_padding=decoder_downsample_padding,
        in_channels=decoder_in_channels,
        layers_per_block=decoder_layers_per_block,
        norm_eps=decoder_norm_eps,
        norm_num_groups=decoder_norm_num_groups,
        num_train_timesteps=decoder_num_train_timesteps,
        out_channels=decoder_out_channels,
        resnet_time_scale_shift=decoder_resnet_time_scale_shift,
        time_embedding_type=decoder_time_embedding_type,
        up_block_types=decoder_up_block_types,
    )
    self.decoder_scheduler = ConsistencyDecoderScheduler()

    # `block_out_channels` is registered here because the tile-size computation below and the
    # upsampling factor in `decode` read it from `self.config`.
    self.register_to_config(block_out_channels=encoder_block_out_channels)
    self.register_to_config(force_upcast=False)

    # Per-channel statistics used by `decode` to normalize the scaled latents. They are shaped
    # (1, 4, 1, 1) and registered as non-persistent buffers.
    self.register_buffer(
        "means",
        torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
        persistent=False,
    )
    self.register_buffer(
        "stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False
    )

    self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)

    self.use_slicing = False
    self.use_tiling = False

    # only relevant if vae tiling is enabled
    # NOTE(review): `tile_sample_min_size` keeps the raw config value, while the latent tile size below
    # unwraps a list/tuple `sample_size` to its first entry — confirm a sequence `sample_size` is
    # actually supported by `encode` / `tiled_encode`, which compare and multiply this attribute.
    self.tile_sample_min_size = self.config.sample_size
    sample_size = (
        self.config.sample_size[0]
        if isinstance(self.config.sample_size, (list, tuple))
        else self.config.sample_size
    )
    # Each entry in `block_out_channels` after the first accounts for one factor-2 reduction here.
    self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
    self.tile_overlap_factor = 0.25
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: Maps `"<module path>.processor"` to the object returned by
        `get_processor()` for every descendant module that exposes `get_processor`.
    """

    def iter_processors(path: str, module: torch.nn.Module):
        # Pre-order walk: a module's own processor is reported before those of its children.
        if hasattr(module, "get_processor"):
            yield f"{path}.processor", module.get_processor()

        for child_name, child in module.named_children():
            yield from iter_processors(f"{path}.{child_name}", child)

    collected = {}
    # Start from the direct children so that keys carry no prefix for the model itself.
    for top_name, top_module in self.named_children():
        collected.update(iter_processors(top_name, top_module))

    return collected
@apply_forward_hook
def encode(
    self, x: torch.Tensor, return_dict: bool = True
) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
    """
    Encode a batch of images into a latent distribution.

    Args:
        x (`torch.Tensor`): Input batch of images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
            instead of a plain tuple.

    Returns:
        The posterior over latents for the encoded images. If `return_dict` is True, a
        [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a
        plain `tuple` is returned.
    """
    if self.use_tiling:
        tile_limit = self.tile_sample_min_size
        # Tiling only kicks in when the input is larger than one tile along width or height.
        if x.shape[-1] > tile_limit or x.shape[-2] > tile_limit:
            return self.tiled_encode(x, return_dict=return_dict)

    slice_wise = self.use_slicing and x.shape[0] > 1
    if slice_wise:
        # Run the encoder one sample at a time, then stitch the batch back together.
        hidden = torch.cat([self.encoder(single) for single in x.split(1)])
    else:
        hidden = self.encoder(x)

    posterior = DiagonalGaussianDistribution(self.quant_conv(hidden))

    return ConsistencyDecoderVAEOutput(latent_dist=posterior) if return_dict else (posterior,)
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
    r"""Encode a batch of images using a tiled encoder.

    The input is split into overlapping square tiles that are encoded one at a time, which keeps the size of
    each encoder call bounded regardless of image size. The result differs from non-tiled encoding because
    every tile is encoded on its own, without access to the pixels of its neighbours. To reduce tiling
    artifacts, the overlapping borders of adjacent tiles are blended together. You may still see tile-sized
    changes in the output, but they should be much less noticeable.

    Args:
        x (`torch.Tensor`): Input batch of images.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
            instead of a plain tuple.

    Returns:
        [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
            If return_dict is True, a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
            is returned, otherwise a plain `tuple` is returned.
    """
    # Despite its name, `overlap_size` is the stride between tile origins in sample space:
    # a full tile minus the overlapping fraction.
    overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
    # Number of latent rows/columns cross-faded between neighbouring tiles.
    blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
    # Latent rows/columns each tile contributes to the final result after cropping.
    row_limit = self.tile_latent_min_size - blend_extent

    # Split the image into overlapping tiles of at most `tile_sample_min_size` per side and encode
    # each one separately (tiles at the right/bottom edge may be smaller).
    rows = []
    for i in range(0, x.shape[2], overlap_size):
        row = []
        for j in range(0, x.shape[3], overlap_size):
            tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
            tile = self.encoder(tile)
            tile = self.quant_conv(tile)
            row.append(tile)
        rows.append(row)
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # blend the above tile and the left tile
            # to the current tile and add the current tile to the result row
            # `blend_v` / `blend_h` write into `tile` in place, so the entries of `rows` are updated as
            # we go and later tiles blend against already-blended neighbours; keep this traversal order.
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent)
            # Crop away the part that the next tile to the right / below will provide.
            result_row.append(tile[:, :, :row_limit, :row_limit])
        result_rows.append(torch.cat(result_row, dim=3))

    moments = torch.cat(result_rows, dim=2)
    posterior = DiagonalGaussianDistribution(moments)

    if not return_dict:
        return (posterior,)

    return ConsistencyDecoderVAEOutput(latent_dist=posterior)
@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
        commit_loss (`torch.FloatTensor`, *optional*, defaults to `None`):
            Optional auxiliary loss value carried alongside the sample. NOTE(review): presumably the
            commitment loss of a vector-quantized model — confirm against the code that populates it.
    """

    sample: torch.Tensor
    # Left as `None` unless the producer explicitly supplies a value.
    commit_loss: Optional[torch.FloatTensor] = None
out_channels (`int`, *optional*, defaults to 3): The number of output channels. down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available options. block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): The number of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups for normalization. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. See `~diffusers.models.activations.get_activation` for available options. double_z (`bool`, *optional*, defaults to `True`): Whether to double the number of output channels for the last block. """ def __init__( self, in_channels: int = 3, out_channels: int = 3, down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), block_out_channels: Tuple[int, ...] 
def forward(self, sample: torch.Tensor) -> torch.Tensor:
    r"""
    Run the encoder: input convolution, all down blocks, the mid block, then norm / activation / output
    convolution.

    Args:
        sample (`torch.Tensor`): Input tensor fed to `conv_in`.

    Returns:
        `torch.Tensor`: Output of `conv_out`.
    """
    hidden = self.conv_in(sample)

    # The down blocks and the mid block are applied back to back with the same calling convention.
    stages = [*self.down_blocks, self.mid_block]

    if self.training and self.gradient_checkpointing:
        # `use_reentrant` is only passed on torch versions that accept it.
        checkpoint_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        def wrap(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        for stage in stages:
            hidden = torch.utils.checkpoint.checkpoint(wrap(stage), hidden, **checkpoint_kwargs)
    else:
        for stage in stages:
            hidden = stage(hidden)

    # post-process
    hidden = self.conv_norm_out(hidden)
    hidden = self.conv_act(hidden)
    return self.conv_out(hidden)
def forward(
    self,
    sample: torch.Tensor,
    latent_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    r"""
    Run the decoder: input convolution, the mid block, all up blocks, then norm / activation / output
    convolution.

    Args:
        sample (`torch.Tensor`): Input tensor fed to `conv_in`.
        latent_embeds (`torch.Tensor`, *optional*):
            Passed as second argument to the mid block and every up block. When it is not `None` it is also
            passed to `conv_norm_out`.

    Returns:
        `torch.Tensor`: Output of `conv_out`.
    """
    hidden = self.conv_in(sample)

    # dtype of the up blocks' parameters; the mid block output is cast to it before upsampling
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

    if self.training and self.gradient_checkpointing:
        # `use_reentrant` is only passed on torch versions that accept it.
        checkpoint_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        def run(module, *inputs):
            def custom_forward(*args):
                return module(*args)

            return torch.utils.checkpoint.checkpoint(custom_forward, *inputs, **checkpoint_kwargs)

    else:

        def run(module, *inputs):
            return module(*inputs)

    # middle
    hidden = run(self.mid_block, hidden, latent_embeds)
    hidden = hidden.to(upscale_dtype)

    # up
    for up_block in self.up_blocks:
        hidden = run(up_block, hidden, latent_embeds)

    # post-process
    if latent_embeds is None:
        hidden = self.conv_norm_out(hidden)
    else:
        hidden = self.conv_norm_out(hidden, latent_embeds)

    return self.conv_out(self.conv_act(hidden))
class MaskConditionEncoder(nn.Module):
    """
    used in AsymmetricAutoencoderKL

    A plain stack of convolutions whose intermediate outputs are all returned, keyed by their shape.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int = 192,
        res_ch: int = 768,
        stride: int = 16,
    ) -> None:
        super().__init__()

        # One (in, out) pair per halving of `stride`. The output width doubles each step and is
        # capped at `res_ch`.
        channel_pairs = []
        while stride > 1:
            stride //= 2
            # The input width is taken before `out_ch` is capped; the last pair uses `res_ch` instead.
            pair_in = res_ch if stride == 1 else out_ch * 2
            out_ch = min(out_ch, res_ch)
            channel_pairs.append((pair_in, out_ch))
            out_ch *= 2

        # Only the output widths drive the layer stack, plus the input width of the last pair.
        widths = [pair_out for _, pair_out in channel_pairs]
        widths.append(channel_pairs[-1][0])

        convs = []
        prev_width = in_ch
        for index, width in enumerate(widths):
            if index < 2:
                # the first two layers keep the spatial resolution
                conv = nn.Conv2d(prev_width, width, kernel_size=3, stride=1, padding=1)
            else:
                # every later layer halves it
                conv = nn.Conv2d(prev_width, width, kernel_size=4, stride=2, padding=1)
            convs.append(conv)
            prev_width = width

        self.layers = nn.Sequential(*convs)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        r"""
        Apply the layers in order and collect every pre-activation output.

        `mask` is accepted but not used. The result is a dict mapping `str(tuple(shape))` of each layer
        output to that output; a later layer producing the same shape replaces the earlier entry.
        """
        features = {}
        for layer in self.layers:
            x = layer(x)
            features[str(tuple(x.shape))] = x
            x = torch.relu(x)
        return features
See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): The number of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups for normalization. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. See `~diffusers.models.activations.get_activation` for available options. norm_type (`str`, *optional*, defaults to `"group"`): The normalization type to use. Can be either `"group"` or `"spatial"`. """ def __init__( self, in_channels: int = 3, out_channels: int = 3, up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), block_out_channels: Tuple[int, ...] = (64,), layers_per_block: int = 2, norm_num_groups: int = 32, act_fn: str = "silu", norm_type: str = "group", # group, spatial ): super().__init__() self.layers_per_block = layers_per_block self.conv_in = nn.Conv2d( in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1, ) self.up_blocks = nn.ModuleList([]) temb_channels = in_channels if norm_type == "spatial" else None # mid self.mid_block = UNetMidBlock2D( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, output_scale_factor=1, resnet_time_scale_shift="default" if norm_type == "group" else norm_type, attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=temb_channels, ) # up reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 up_block = get_up_block( up_block_type, num_layers=self.layers_per_block + 1, in_channels=prev_output_channel, out_channels=output_channel, prev_output_channel=None, 
def forward(
    self,
    z: torch.Tensor,
    image: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
    latent_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    r"""
    Decode `z`, optionally mixing in features of the masked image.

    Args:
        z (`torch.Tensor`): Input tensor fed to `conv_in`.
        image (`torch.Tensor`, *optional*): Image to condition on. Only used when `mask` is given too.
        mask (`torch.Tensor`, *optional*):
            Blending weights. Where the mask is 1 the decoded sample is kept, where it is 0 the condition
            encoder's features are used. Only used when `image` is given too.
        latent_embeds (`torch.Tensor`, *optional*):
            Passed as second argument to the mid block and every up block. When it is not `None` it is also
            passed to `conv_norm_out`.

    Returns:
        `torch.Tensor`: Output of `conv_out`.
    """
    hidden = self.conv_in(z)

    # dtype of the up blocks' parameters; the mid block output is cast to it before upsampling
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
    conditioned = image is not None and mask is not None

    if self.training and self.gradient_checkpointing:
        # `use_reentrant` is only passed on torch versions that accept it.
        checkpoint_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        def run(module, *inputs):
            def custom_forward(*args):
                return module(*args)

            return torch.utils.checkpoint.checkpoint(custom_forward, *inputs, **checkpoint_kwargs)

    else:

        def run(module, *inputs):
            return module(*inputs)

    # middle
    hidden = run(self.mid_block, hidden, latent_embeds)
    hidden = hidden.to(upscale_dtype)

    # condition encoder: features of the image with the masked region zeroed, keyed by shape
    if conditioned:
        masked_image = (1 - mask) * image
        cond_features = run(self.condition_encoder, masked_image, mask)

    # up
    for up_block in self.up_blocks:
        if conditioned:
            # pick the feature map whose shape matches the current sample and blend it in
            matching = cond_features[str(tuple(hidden.shape))]
            resized_mask = nn.functional.interpolate(mask, size=hidden.shape[-2:], mode="nearest")
            hidden = hidden * resized_mask + matching * (1 - resized_mask)
        hidden = run(up_block, hidden, latent_embeds)

    # final blend with the unresized mask
    if conditioned:
        hidden = hidden * mask + cond_features[str(tuple(hidden.shape))] * (1 - mask)

    # post-process
    if latent_embeds is None:
        hidden = self.conv_norm_out(hidden)
    else:
        hidden = self.conv_norm_out(hidden, latent_embeds)

    return self.conv_out(self.conv_act(hidden))
drop-in replacement. Mostly avoids costly matrix multiplications and allows for post-hoc remapping of indices. """ # NOTE: due to a bug the beta term was applied to the wrong term. for # backwards compatibility we use the buggy version by default, but you can # specify legacy=False to fix it. def __init__( self, n_e: int, vq_embed_dim: int, beta: float, remap=None, unknown_index: str = "random", sane_index_shape: bool = False, legacy: bool = True, ): super().__init__() self.n_e = n_e self.vq_embed_dim = vq_embed_dim self.beta = beta self.legacy = legacy self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) self.remap = remap if self.remap is not None: self.register_buffer("used", torch.tensor(np.load(self.remap))) self.used: torch.Tensor self.re_embed = self.used.shape[0] self.unknown_index = unknown_index # "random" or "extra" or integer if self.unknown_index == "extra": self.unknown_index = self.re_embed self.re_embed = self.re_embed + 1 print( f"Remapping {self.n_e} indices to {self.re_embed} indices. " f"Using {self.unknown_index} for unknown indices." 
) else: self.re_embed = n_e self.sane_index_shape = sane_index_shape def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor: ishape = inds.shape assert len(ishape) > 1 inds = inds.reshape(ishape[0], -1) used = self.used.to(inds) match = (inds[:, :, None] == used[None, None, ...]).long() new = match.argmax(-1) unknown = match.sum(2) < 1 if self.unknown_index == "random": new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) else: new[unknown] = self.unknown_index return new.reshape(ishape) def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor: ishape = inds.shape assert len(ishape) > 1 inds = inds.reshape(ishape[0], -1) used = self.used.to(inds) if self.re_embed > self.used.shape[0]: # extra token inds[inds >= self.used.shape[0]] = 0 # simply set to zero back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) return back.reshape(ishape) def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]: # reshape z -> (batch, height, width, channel) and flatten z = z.permute(0, 2, 3, 1).contiguous() z_flattened = z.view(-1, self.vq_embed_dim) # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1) z_q = self.embedding(min_encoding_indices).view(z.shape) perplexity = None min_encodings = None # compute loss for embedding if not self.legacy: loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) else: loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) # preserve gradients z_q: torch.Tensor = z + (z_q - z).detach() # reshape back to match original input shape z_q = z_q.permute(0, 3, 1, 2).contiguous() if self.remap is not None: min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis min_encoding_indices = self.remap_to_used(min_encoding_indices) min_encoding_indices = 
class DiagonalGaussianDistribution(object):
    """Gaussian with diagonal covariance, parameterized by a tensor holding mean and log-variance.

    Args:
        parameters (`torch.Tensor`):
            Tensor that is split into two equal chunks along dim 1: the mean and the log-variance. The
            log-variance is clamped to `[-30, 20]`.
        deterministic (`bool`, defaults to `False`):
            When `True`, variance and standard deviation are forced to zero, and `kl`/`nll` return a single zero.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def _zero(self) -> torch.Tensor:
        """Return a one-element zero tensor on the device and in the dtype of `parameters`."""
        # A bare `torch.Tensor([0.0])` is always a CPU float32 tensor, which does not follow the
        # distribution's own device/dtype the way `sample` and the deterministic `var`/`std` do.
        return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Draw `mean + std * noise`, with noise created on the parameters' device and dtype."""
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        """KL divergence to `other`, or to a standard normal when `other` is `None`.

        The divergence is summed over dims 1, 2 and 3. A deterministic distribution returns a single zero.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3],
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2, 3],
        )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        """Negative log-likelihood of `sample`, summed over `dims`.

        A deterministic distribution returns a single zero.
        """
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean
class EncoderTiny(nn.Module):
    r"""
    The `EncoderTiny` layer is a simpler version of the `Encoder` layer.

    Args:
        in_channels (`int`):
            The number of input channels.
        out_channels (`int`):
            The number of output channels.
        num_blocks (`Tuple[int, ...]`):
            Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
            use.
        block_out_channels (`Tuple[int, ...]`):
            The number of output channels for each block.
        act_fn (`str`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: Tuple[int, ...],
        block_out_channels: Tuple[int, ...],
        act_fn: str,
    ):
        super().__init__()

        layers = []
        for i, num_block in enumerate(num_blocks):
            num_channels = block_out_channels[i]

            if i == 0:
                layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
            else:
                # strided conv halves the spatial resolution between stages
                # NOTE(review): input width is taken from the *current* stage, which presumably relies on all
                # entries of `block_out_channels` being equal — confirm before using varying widths.
                layers.append(
                    nn.Conv2d(
                        num_channels,
                        num_channels,
                        kernel_size=3,
                        padding=1,
                        stride=2,
                        bias=False,
                    )
                )

            for _ in range(num_block):
                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))

        layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))

        self.layers = nn.Sequential(*layers)
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""The forward method of the `EncoderTiny` class."""
        # scale image from [-1, 1] to [0, 1] to match TAESD convention.
        # This is done before branching: gradient checkpointing only trades compute for memory and must not
        # change what the layers see, so both paths have to receive the same rescaled input.
        x = x.add(1).div(2)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
            else:
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)

        else:
            x = self.layers(x)

        return x
class DecoderTiny(nn.Module):
    r"""
    The `DecoderTiny` layer is a simpler version of the `Decoder` layer.

    Args:
        in_channels (`int`):
            The number of input channels.
        out_channels (`int`):
            The number of output channels.
        num_blocks (`Tuple[int, ...]`):
            Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
            use.
        block_out_channels (`Tuple[int, ...]`):
            The number of output channels for each block.
        upsampling_scaling_factor (`int`):
            The scaling factor to use for upsampling.
        act_fn (`str`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        upsample_fn (`str`):
            Passed as `mode` to the `nn.Upsample` layers inserted between stages.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: Tuple[int, ...],
        block_out_channels: Tuple[int, ...],
        upsampling_scaling_factor: int,
        act_fn: str,
        upsample_fn: str,
    ):
        super().__init__()

        # stem: input conv followed by the activation
        modules = [
            nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
            get_activation(act_fn),
        ]

        last_stage = len(num_blocks) - 1
        for stage, depth in enumerate(num_blocks):
            width = block_out_channels[stage]
            closing = stage == last_stage

            modules.extend(AutoencoderTinyBlock(width, width, act_fn) for _ in range(depth))

            # every stage but the last one is followed by an upsampling layer
            if not closing:
                modules.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn))

            # the last stage projects to `out_channels` and is the only conv with a bias
            stage_out = out_channels if closing else width
            modules.append(
                nn.Conv2d(
                    width,
                    stage_out,
                    kernel_size=3,
                    padding=1,
                    bias=closing,
                )
            )

        self.layers = nn.Sequential(*modules)
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""The forward method of the `DecoderTiny` class."""
        # Clamp.
        bounded = torch.tanh(x / 3) * 3

        checkpointed = self.training and self.gradient_checkpointing
        if not checkpointed:
            decoded = self.layers(bounded)
        else:

            def run_layers(*inputs):
                return self.layers(*inputs)

            if is_torch_version(">=", "1.11.0"):
                decoded = torch.utils.checkpoint.checkpoint(run_layers, bounded, use_reentrant=False)
            else:
                decoded = torch.utils.checkpoint.checkpoint(run_layers, bounded)

        # scale image from [0, 1] to [-1, 1] to match diffusers convention
        return decoded.mul(2).sub(1)
@dataclass
class VQEncoderOutput(BaseOutput):
    """
    Output of VQModel encoding method.

    Args:
        latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The encoded output sample from the last layer of the model.
    """

    # Encoded latents produced by the encoding step.
    latents: torch.Tensor
class VQModel(ModelMixin, ConfigMixin):
    r"""
    A VQ-VAE model for decoding latent representations.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int,  *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
        norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
        vq_embed_dim (`int`, *optional*):
            Hidden dim of codebook vectors in the VQ-VAE. Falls back to `latent_channels` when left as `None`.
        scaling_factor (`float`, *optional*, defaults to `0.18215`):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        norm_type (`str`, *optional*, defaults to `"group"`):
            Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to both the encoder and the decoder.
        lookup_from_codebook (`bool`, *optional*, defaults to `False`):
            When `True` and `decode` is called with `force_not_quantize=True`, the input of `decode` is treated as
            codebook indices and resolved through the quantizer's codebook.
        force_upcast (`bool`, *optional*, defaults to `False`):
            Accepted by the constructor; not referenced elsewhere in this class.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 3,
        sample_size: int = 32,
        num_vq_embeddings: int = 256,
        norm_num_groups: int = 32,
        vq_embed_dim: Optional[int] = None,
        scaling_factor: float = 0.18215,
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention=True,
        lookup_from_codebook=False,
        force_upcast=False,
    ):
        super().__init__()

        # settings that the encoder and the decoder have in common
        common_kwargs = {
            "block_out_channels": block_out_channels,
            "layers_per_block": layers_per_block,
            "act_fn": act_fn,
            "norm_num_groups": norm_num_groups,
            "mid_block_add_attention": mid_block_add_attention,
        }

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            double_z=False,
            **common_kwargs,
        )

        # the codebook works in the latent width unless a dedicated width was requested
        if vq_embed_dim is None:
            vq_embed_dim = latent_channels

        self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
        self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
        self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)

        # pass init params to Decoder
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            norm_type=norm_type,
            **common_kwargs,
        )

    @apply_forward_hook
    def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
        """Run the encoder and the 1x1 `quant_conv` on `x`.

        Returns a [`VQEncoderOutput`] holding the result, or a one-element tuple when `return_dict` is `False`.
        No quantization happens here; that is part of `decode`.
        """
        latents = self.quant_conv(self.encoder(x))

        if return_dict:
            return VQEncoderOutput(latents=latents)
        return (latents,)

    @apply_forward_hook
    def decode(
        self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
    ) -> Union[DecoderOutput, torch.Tensor]:
        """Quantize (unless told otherwise) and decode `h`.

        Args:
            h (`torch.Tensor`): Latents, or codebook indices when the codebook-lookup path is taken.
            force_not_quantize (`bool`): Skip the quantizer. `h` is then either looked up in the codebook (config
                `lookup_from_codebook`) or used directly, and the commit loss is a zero tensor of length
                `h.shape[0]`.
            return_dict (`bool`): Return a [`DecoderOutput`] instead of a `(sample, commit_loss)` tuple.
            shape: Forwarded to `get_codebook_entry` on the codebook-lookup path.
        """
        # also go through quantization layer
        if not force_not_quantize:
            quant, commit_loss, _ = self.quantize(h)
        else:
            if self.config.lookup_from_codebook:
                quant = self.quantize.get_codebook_entry(h, shape)
            else:
                quant = h
            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)

        quant2 = self.post_quant_conv(quant)
        # spatial normalization is conditioned on the quantized latents; group norm takes no conditioning
        conditioning = quant if self.config.norm_type == "spatial" else None
        dec = self.decoder(quant2, conditioning)

        if return_dict:
            return DecoderOutput(sample=dec, commit_loss=commit_loss)
        return dec, commit_loss

    def forward(
        self, sample: torch.Tensor, return_dict: bool = True
    ) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
        r"""
        The [`VQModel`] forward method.

        Args:
            sample (`torch.Tensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return the decoder output object instead of a plain tuple.

        Returns:
            The object returned by `decode` (a `DecoderOutput`) when `return_dict` is True, otherwise the plain
            tuple `(sample, commit_loss)`.
        """
        decoded = self.decode(self.encode(sample).latents)

        if return_dict:
            return decoded
        return decoded.sample, decoded.commit_loss
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F from ..configuration_utils import ConfigMixin, register_to_config from ..loaders.single_file_model import FromOriginalModelMixin from ..utils import BaseOutput, logging from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, ) from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unets.unet_2d_blocks import ( CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block, ) from .unets.unet_2d_condition import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class ControlNetOutput(BaseOutput): """ The output of [`ControlNetModel`]. Args: down_block_res_samples (`tuple[torch.Tensor]`): A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be used to condition the original UNet's downsampling activations. mid_down_block_re_sample (`torch.Tensor`): The activation of the middle block (the lowest sample resolution). Each tensor should be of shape `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. 
class ControlNetConditioningEmbedding(nn.Module):
    """
    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."

    As implemented here: a 3x3 input conv, then for each consecutive pair of `block_out_channels` a 3x3 conv that
    keeps the width followed by a stride-2 3x3 conv that changes it, all with SiLU activations, and finally a 3x3
    output conv wrapped in `zero_module`.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
    ):
        super().__init__()

        first_width = block_out_channels[0]
        last_width = block_out_channels[-1]

        self.conv_in = nn.Conv2d(conditioning_channels, first_width, kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])
        # walk over consecutive (width, next_width) pairs
        for width, next_width in zip(block_out_channels[:-1], block_out_channels[1:]):
            same_size = nn.Conv2d(width, width, kernel_size=3, padding=1)
            downsample = nn.Conv2d(width, next_width, kernel_size=3, padding=1, stride=2)
            self.blocks.append(same_size)
            self.blocks.append(downsample)

        out_conv = nn.Conv2d(last_width, conditioning_embedding_channels, kernel_size=3, padding=1)
        self.conv_out = zero_module(out_conv)

    def forward(self, conditioning):
        """Embed the conditioning image: every conv except the last one is followed by SiLU."""
        hidden = F.silu(self.conv_in(conditioning))

        for conv in self.blocks:
            hidden = F.silu(conv(hidden))

        return self.conv_out(hidden)
flip_sin_to_cos (`bool`, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, defaults to 2): The number of layers per block. downsample_padding (`int`, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, defaults to 1): The scale factor to use for the mid block. act_fn (`str`, defaults to "silu"): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If None, normalization and activation layers is skipped in post-processing. norm_eps (`float`, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int`, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. 
encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): The dimension of the attention heads. use_linear_projection (`bool`, defaults to `False`): class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. num_class_embeds (`int`, *optional*, defaults to 0): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. upcast_attention (`bool`, defaults to `False`): resnet_time_scale_shift (`str`, defaults to `"default"`): Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): The channel order of conditional image. Will convert to `rgb` if it's `bgr`. conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): The tuple of output channel for each block in the `conditioning_embedding` layer. global_pool_conditions (`bool`, defaults to `False`): TODO(Patrick) - unused parameter. 
addition_embed_type_num_heads (`int`, defaults to 64): The number of heads to use for the `TextTimeEmbedding` layer. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, in_channels: int = 4, conditioning_channels: int = 3, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str, ...] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int, ...]] = 8, num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", projection_class_embeddings_input_dim: Optional[int] = None, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), global_pool_conditions: bool = False, addition_embed_type_num_heads: int = 64, ): super().__init__() # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. 
The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) # input conv_in_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time time_embed_dim = block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, ) if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." ) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type is not None: raise ValueError( f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 
) else: self.encoder_hid_proj = None # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None if addition_embed_type == "text": if encoder_hid_dim is not None: text_time_embedding_from_dim = encoder_hid_dim else: text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") # control net conditioning embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], block_out_channels=conditioning_embedding_out_channels, conditioning_channels=conditioning_channels, ) self.down_blocks = nn.ModuleList([]) self.controlnet_down_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) # down output_channel = block_out_channels[0] controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, 
temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[i], attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, ) self.down_blocks.append(down_block) for _ in range(layers_per_block): controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) if not is_final_block: controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) # mid mid_block_channel = block_out_channels[-1] controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_mid_block = controlnet_block if mid_block_type == "UNetMidBlock2DCrossAttn": self.mid_block = UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=mid_block_channel, temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) elif mid_block_type == "UNetMidBlock2D": self.mid_block = UNetMidBlock2D( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, num_layers=0, resnet_eps=norm_eps, resnet_act_fn=act_fn, 
output_scale_factor=mid_block_scale_factor, resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, add_attention=False, ) else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") @classmethod def from_unet( cls, unet: UNet2DConditionModel, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), load_weights_from_unet: bool = True, conditioning_channels: int = 3, ): r""" Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. Parameters: unet (`UNet2DConditionModel`): The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied where applicable. """ transformer_layers_per_block = ( unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 ) encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None addition_time_embed_dim = ( unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None ) controlnet = cls( encoder_hid_dim=encoder_hid_dim, encoder_hid_dim_type=encoder_hid_dim_type, addition_embed_type=addition_embed_type, addition_time_embed_dim=addition_time_embed_dim, transformer_layers_per_block=transformer_layers_per_block, in_channels=unet.config.in_channels, flip_sin_to_cos=unet.config.flip_sin_to_cos, freq_shift=unet.config.freq_shift, down_block_types=unet.config.down_block_types, only_cross_attention=unet.config.only_cross_attention, block_out_channels=unet.config.block_out_channels, layers_per_block=unet.config.layers_per_block, downsample_padding=unet.config.downsample_padding, mid_block_scale_factor=unet.config.mid_block_scale_factor, act_fn=unet.config.act_fn, 
norm_num_groups=unet.config.norm_num_groups, norm_eps=unet.config.norm_eps, cross_attention_dim=unet.config.cross_attention_dim, attention_head_dim=unet.config.attention_head_dim, num_attention_heads=unet.config.num_attention_heads, use_linear_projection=unet.config.use_linear_projection, class_embed_type=unet.config.class_embed_type, num_class_embeds=unet.config.num_class_embeds, upcast_attention=unet.config.upcast_attention, resnet_time_scale_shift=unet.config.resnet_time_scale_shift, projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, mid_block_type=unet.config.mid_block_type, controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, conditioning_embedding_out_channels=conditioning_embedding_out_channels, conditioning_channels=conditioning_channels, ) if load_weights_from_unet: controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) if controlnet.class_embedding: controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) if hasattr(controlnet, "add_embedding"): controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict()) controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) return controlnet @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. 
""" # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnAddedKVProcessor() elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnProcessor() else: raise ValueError( f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) self.set_attn_processor(processor) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. 
""" sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. 
# Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): module.gradient_checkpointing = value def forward( self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, controlnet_cond: torch.Tensor, conditioning_scale: float = 1.0, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guess_mode: bool = False, return_dict: bool = True, ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: """ The [`ControlNetModel`] forward method. Args: sample (`torch.Tensor`): The noisy input tensor. timestep (`Union[torch.Tensor, float, int]`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.Tensor`): The encoder hidden states. controlnet_cond (`torch.Tensor`): The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. conditioning_scale (`float`, defaults to `1.0`): The scale factor for ControlNet outputs. class_labels (`torch.Tensor`, *optional*, defaults to `None`): Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): Additional conditional embeddings for timestep. 
If provided, the embeddings will be summed with the timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep embeddings. attention_mask (`torch.Tensor`, *optional*, defaults to `None`): An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. added_cond_kwargs (`dict`): Additional conditions for the Stable Diffusion XL UNet. cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttnProcessor`. guess_mode (`bool`, defaults to `False`): In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. return_dict (`bool`, defaults to `True`): Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. Returns: [`~models.controlnet.ControlNetOutput`] **or** `tuple`: If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ # check channel order channel_order = self.config.controlnet_conditioning_channel_order if channel_order == "rgb": # in rgb order by default ... elif channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) else: raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb if self.config.addition_embed_type is not None: if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_time": if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = 
self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) emb = emb + aug_emb if aug_emb is not None else emb # 2. pre-process sample = self.conv_in(sample) controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) sample = sample + controlnet_cond # 3. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) down_block_res_samples += res_samples # 4. mid if self.mid_block is not None: if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, ) else: sample = self.mid_block(sample, emb) # 5. Control net blocks controlnet_down_block_res_samples = () for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): down_block_res_sample = controlnet_block(down_block_res_sample) controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = controlnet_down_block_res_samples mid_block_res_sample = self.controlnet_mid_block(sample) # 6. 
scaling if guess_mode and not self.config.global_pool_conditions: scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 scales = scales * conditioning_scale down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] mid_block_res_sample = mid_block_res_sample * scales[-1] # last one else: down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: down_block_res_samples = [ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples ] mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) if not return_dict: return (down_block_res_samples, mid_block_res_sample) return ControlNetOutput( down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample ) def zero_module(module): for p in module.parameters(): nn.init.zeros_(p) return module ================================================ FILE: src/diffusers/models/controlnet_flax.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Optional, Tuple, Union import flax import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict from ..configuration_utils import ConfigMixin, flax_register_to_config from ..utils import BaseOutput from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps from .modeling_flax_utils import FlaxModelMixin from .unets.unet_2d_blocks_flax import ( FlaxCrossAttnDownBlock2D, FlaxDownBlock2D, FlaxUNetMidBlock2DCrossAttn, ) @flax.struct.dataclass class FlaxControlNetOutput(BaseOutput): """ The output of [`FlaxControlNetModel`]. Args: down_block_res_samples (`jnp.ndarray`): mid_block_res_sample (`jnp.ndarray`): """ down_block_res_samples: jnp.ndarray mid_block_res_sample: jnp.ndarray class FlaxControlNetConditioningEmbedding(nn.Module): conditioning_embedding_channels: int block_out_channels: Tuple[int, ...] = (16, 32, 96, 256) dtype: jnp.dtype = jnp.float32 def setup(self) -> None: self.conv_in = nn.Conv( self.block_out_channels[0], kernel_size=(3, 3), padding=((1, 1), (1, 1)), dtype=self.dtype, ) blocks = [] for i in range(len(self.block_out_channels) - 1): channel_in = self.block_out_channels[i] channel_out = self.block_out_channels[i + 1] conv1 = nn.Conv( channel_in, kernel_size=(3, 3), padding=((1, 1), (1, 1)), dtype=self.dtype, ) blocks.append(conv1) conv2 = nn.Conv( channel_out, kernel_size=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)), dtype=self.dtype, ) blocks.append(conv2) self.blocks = blocks self.conv_out = nn.Conv( self.conditioning_embedding_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, ) def __call__(self, conditioning: jnp.ndarray) -> jnp.ndarray: embedding = self.conv_in(conditioning) embedding = nn.silu(embedding) for block in self.blocks: embedding = block(embedding) embedding = nn.silu(embedding) embedding = self.conv_out(embedding) return embedding @flax_register_to_config 
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): r""" A ControlNet model. This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it’s generic methods implemented for all models (such as downloading or saving). This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module) subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its general usage and behavior. Inherent JAX features such as the following are supported: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: sample_size (`int`, *optional*): The size of the input sample. in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`): The tuple of downsample blocks to use. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int` or `Tuple[int]`, *optional*): The number of attention heads. cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features. dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks. 
flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`): The channel order of conditional image. Will convert to `rgb` if it's `bgr`. conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`): The tuple of output channel for each block in the `conditioning_embedding` layer. """ sample_size: int = 32 in_channels: int = 4 down_block_types: Tuple[str, ...] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ) only_cross_attention: Union[bool, Tuple[bool, ...]] = False block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280) layers_per_block: int = 2 attention_head_dim: Union[int, Tuple[int, ...]] = 8 num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None cross_attention_dim: int = 1280 dropout: float = 0.0 use_linear_projection: bool = False dtype: jnp.dtype = jnp.float32 flip_sin_to_cos: bool = True freq_shift: int = 0 controlnet_conditioning_channel_order: str = "rgb" conditioning_embedding_out_channels: Tuple[int, ...] 
= (16, 32, 96, 256) def init_weights(self, rng: jax.Array) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) timesteps = jnp.ones((1,), dtype=jnp.int32) encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8) controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"] def setup(self) -> None: block_out_channels = self.block_out_channels time_embed_dim = block_out_channels[0] * 4 # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. 
num_attention_heads = self.num_attention_heads or self.attention_head_dim # input self.conv_in = nn.Conv( block_out_channels[0], kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) # time self.time_proj = FlaxTimesteps( block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift ) self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], block_out_channels=self.conditioning_embedding_out_channels, ) only_cross_attention = self.only_cross_attention if isinstance(only_cross_attention, bool): only_cross_attention = (only_cross_attention,) * len(self.down_block_types) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(self.down_block_types) # down down_blocks = [] controlnet_down_blocks = [] output_channel = block_out_channels[0] controlnet_block = nn.Conv( output_channel, kernel_size=(1, 1), padding="VALID", kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, ) controlnet_down_blocks.append(controlnet_block) for i, down_block_type in enumerate(self.down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 if down_block_type == "CrossAttnDownBlock2D": down_block = FlaxCrossAttnDownBlock2D( in_channels=input_channel, out_channels=output_channel, dropout=self.dropout, num_layers=self.layers_per_block, num_attention_heads=num_attention_heads[i], add_downsample=not is_final_block, use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], dtype=self.dtype, ) else: down_block = FlaxDownBlock2D( in_channels=input_channel, out_channels=output_channel, dropout=self.dropout, num_layers=self.layers_per_block, add_downsample=not is_final_block, dtype=self.dtype, ) 
down_blocks.append(down_block) for _ in range(self.layers_per_block): controlnet_block = nn.Conv( output_channel, kernel_size=(1, 1), padding="VALID", kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, ) controlnet_down_blocks.append(controlnet_block) if not is_final_block: controlnet_block = nn.Conv( output_channel, kernel_size=(1, 1), padding="VALID", kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, ) controlnet_down_blocks.append(controlnet_block) self.down_blocks = down_blocks self.controlnet_down_blocks = controlnet_down_blocks # mid mid_block_channel = block_out_channels[-1] self.mid_block = FlaxUNetMidBlock2DCrossAttn( in_channels=mid_block_channel, dropout=self.dropout, num_attention_heads=num_attention_heads[-1], use_linear_projection=self.use_linear_projection, dtype=self.dtype, ) self.controlnet_mid_block = nn.Conv( mid_block_channel, kernel_size=(1, 1), padding="VALID", kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, ) def __call__( self, sample: jnp.ndarray, timesteps: Union[jnp.ndarray, float, int], encoder_hidden_states: jnp.ndarray, controlnet_cond: jnp.ndarray, conditioning_scale: float = 1.0, return_dict: bool = True, train: bool = False, ) -> Union[FlaxControlNetOutput, Tuple[Tuple[jnp.ndarray, ...], jnp.ndarray]]: r""" Args: sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor timestep (`jnp.ndarray` or `float` or `int`): timesteps encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] 
instead of a plain tuple. train (`bool`, *optional*, defaults to `False`): Use deterministic functions and disable dropout when not training. Returns: [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ channel_order = self.controlnet_conditioning_channel_order if channel_order == "bgr": controlnet_cond = jnp.flip(controlnet_cond, axis=1) # 1. time if not isinstance(timesteps, jnp.ndarray): timesteps = jnp.array([timesteps], dtype=jnp.int32) elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: timesteps = timesteps.astype(dtype=jnp.float32) timesteps = jnp.expand_dims(timesteps, 0) t_emb = self.time_proj(timesteps) t_emb = self.time_embedding(t_emb) # 2. pre-process sample = jnp.transpose(sample, (0, 2, 3, 1)) sample = self.conv_in(sample) controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1)) controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) sample += controlnet_cond # 3. down down_block_res_samples = (sample,) for down_block in self.down_blocks: if isinstance(down_block, FlaxCrossAttnDownBlock2D): sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train) else: sample, res_samples = down_block(sample, t_emb, deterministic=not train) down_block_res_samples += res_samples # 4. mid sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) # 5. contronet blocks controlnet_down_block_res_samples = () for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): down_block_res_sample = controlnet_block(down_block_res_sample) controlnet_down_block_res_samples += (down_block_res_sample,) down_block_res_samples = controlnet_down_block_res_samples mid_block_res_sample = self.controlnet_mid_block(sample) # 6. 
scaling down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] mid_block_res_sample *= conditioning_scale if not return_dict: return (down_block_res_samples, mid_block_res_sample) return FlaxControlNetOutput( down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample ) ================================================ FILE: src/diffusers/models/controlnet_hunyuan.py ================================================ # Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
@dataclass
class HunyuanControlNetOutput(BaseOutput):
    """
    Output of `HunyuanDiT2DControlNetModel.forward` when it is called with `return_dict=True`.
    """

    # One entry per controlnet block: that block's hidden states passed through the matching
    # `controlnet_blocks` linear layer and multiplied by `conditioning_scale`.
    # NOTE: `forward` builds this value as a list, even though the annotation says tuple.
    controlnet_block_samples: Tuple[torch.Tensor]
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Collect every attention processor used by the model.

    Returns:
        `dict` of attention processors: Maps `"<dotted module path>.processor"` to the object returned by that
        module's `get_processor(return_deprecated_lora=True)`. Only sub-modules exposing `get_processor` contribute
        an entry; entries appear in depth-first, parent-before-child order.
    """
    collected = {}

    # Explicit stack instead of recursion. Children are pushed in reverse so that they are popped
    # (and therefore visited) in registration order, parents before their own children.
    pending = list(reversed(list(self.named_children())))
    while pending:
        path, module = pending.pop()

        if hasattr(module, "get_processor"):
            collected[f"{path}.processor"] = module.get_processor(return_deprecated_lora=True)

        nested = [(f"{path}.{child_name}", child) for child_name, child in module.named_children()]
        pending.extend(reversed(nested))

    return collected
@classmethod
def from_transformer(
    cls, transformer, conditioning_channels=3, transformer_num_layers=None, load_weights_from_transformer=True
):
    """
    Build a controlnet whose architecture mirrors an existing transformer, optionally copying its weights.

    Args:
        transformer:
            Source model. Its `config` supplies the architecture hyper-parameters and, when
            `load_weights_from_transformer` is set, its `state_dict()` supplies the weights.
        conditioning_channels (`int`, defaults to 3):
            Forwarded unchanged to the constructor.
        transformer_num_layers (`int`, *optional*):
            Depth to build the controlnet for. When omitted (or falsy) the value is taken from the transformer
            config.
        load_weights_from_transformer (`bool`, defaults to `True`):
            Whether to load the transformer's state dict into the new controlnet (non-strict, so keys that exist
            on only one side are tolerated; the missing keys are logged).

    Returns:
        A new instance of `cls`.
    """
    config = transformer.config

    if not transformer_num_layers:
        if hasattr(config, "transformer_num_layers"):
            transformer_num_layers = config.transformer_num_layers
        else:
            # NOTE(review): configs without `transformer_num_layers` previously crashed here with an
            # AttributeError; presumably the transformer stores its depth as `num_layers` - confirm.
            transformer_num_layers = config.num_layers

    init_kwargs = {
        "conditioning_channels": conditioning_channels,
        "transformer_num_layers": transformer_num_layers,
        "activation_fn": config.activation_fn,
        "attention_head_dim": config.attention_head_dim,
        "cross_attention_dim": config.cross_attention_dim,
        "cross_attention_dim_t5": config.cross_attention_dim_t5,
        "hidden_size": config.hidden_size,
        "in_channels": config.in_channels,
        "mlp_ratio": config.mlp_ratio,
        "num_attention_heads": config.num_attention_heads,
        "patch_size": config.patch_size,
        "sample_size": config.sample_size,
        "text_len": config.text_len,
        "text_len_t5": config.text_len_t5,
    }

    # These two options shape `time_extra_emb` in `__init__`. Forward them when the transformer defines
    # them so the controlnet is built like the model whose weights are loaded below; otherwise the
    # constructor defaults apply exactly as before.
    for optional_key in ("pooled_projection_dim", "use_style_cond_and_image_meta_size"):
        if hasattr(config, optional_key):
            init_kwargs[optional_key] = getattr(config, optional_key)

    controlnet = cls(**init_kwargs)

    if load_weights_from_transformer:
        # Index 0 of the load result is the list of missing keys.
        key = controlnet.load_state_dict(transformer.state_dict(), strict=False)
        logger.warning(f"controlnet load from Hunyuan-DiT. missing_keys: {key[0]}")
    return controlnet
image_meta_size (torch.Tensor): Conditional embedding indicate the image sizes style: torch.Tensor: Conditional embedding indicate the style image_rotary_emb (`torch.Tensor`): The image rotary embeddings to apply on query and key tensors during attention calculation. return_dict: bool Whether to return a dictionary. """ height, width = hidden_states.shape[-2:] hidden_states = self.pos_embed(hidden_states) # b,c,H,W -> b, N, C # 2. pre-process hidden_states = hidden_states + self.input_block(self.pos_embed(controlnet_cond)) temb = self.time_extra_emb( timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype ) # [B, D] # text projection batch_size, sequence_length, _ = encoder_hidden_states_t5.shape encoder_hidden_states_t5 = self.text_embedder( encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1]) ) encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1) encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1) text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1) text_embedding_mask = text_embedding_mask.unsqueeze(2).bool() encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding) block_res_samples = () for layer, block in enumerate(self.blocks): hidden_states = block( hidden_states, temb=temb, encoder_hidden_states=encoder_hidden_states, image_rotary_emb=image_rotary_emb, ) # (N, L, D) block_res_samples = block_res_samples + (hidden_states,) controlnet_block_res_samples = () for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): block_res_sample = controlnet_block(block_res_sample) controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) # 6. 
class HunyuanDiT2DMultiControlNetModel(ModelMixin):
    r"""
    `HunyuanDiT2DMultiControlNetModel` wrapper class for Multi-HunyuanDiT2DControlNetModel

    This module is a wrapper for multiple instances of the `HunyuanDiT2DControlNetModel`. The `forward()` API is
    designed to be compatible with `HunyuanDiT2DControlNetModel`.

    Args:
        controlnets (`List[HunyuanDiT2DControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `HunyuanDiT2DControlNetModel` as a list.
    """

    def __init__(self, controlnets):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        hidden_states,
        timestep,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        encoder_hidden_states=None,
        text_embedding_mask=None,
        encoder_hidden_states_t5=None,
        text_embedding_mask_t5=None,
        image_meta_size=None,
        style=None,
        image_rotary_emb=None,
        return_dict=True,
    ):
        """
        The [`HunyuanDiT2DMultiControlNetModel`] forward method.

        Runs every wrapped controlnet on the same inputs, each with its own conditioning input and scale, and sums
        their block samples position by position.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
                The input tensor.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step.
            controlnet_cond:
                One conditioning input per wrapped controlnet; it is iterated in lockstep with the controlnets.
            conditioning_scale (`float` or sequence of `float`):
                One scale per wrapped controlnet. A single number is applied to every controlnet.
            encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. This is the output of `BertModel`.
            text_embedding_mask: torch.Tensor
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the
                output of `BertModel`.
            encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
            text_embedding_mask_t5: torch.Tensor
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the
                output of T5 Text Encoder.
            image_meta_size (torch.Tensor):
                Conditional embedding indicate the image sizes
            style: torch.Tensor:
                Conditional embedding indicate the style
            image_rotary_emb (`torch.Tensor`):
                The image rotary embeddings to apply on query and key tensors during attention calculation.
            return_dict: bool
                Forwarded to every wrapped controlnet.

        Returns:
            With exactly one controlnet evaluated, that controlnet's own return value. With more than one, a
            1-tuple holding the list of summed block samples, regardless of `return_dict`.

        Raises:
            ValueError: If there is nothing to evaluate (no controlnets, conditioning inputs or scales).
        """
        # A bare number cannot be zipped with the controlnets; apply it to each of them instead.
        if isinstance(conditioning_scale, (int, float)):
            conditioning_scale = [conditioning_scale] * len(self.nets)

        control_block_samples = None
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            block_samples = controlnet(
                hidden_states=hidden_states,
                timestep=timestep,
                controlnet_cond=image,
                conditioning_scale=scale,
                encoder_hidden_states=encoder_hidden_states,
                text_embedding_mask=text_embedding_mask,
                encoder_hidden_states_t5=encoder_hidden_states_t5,
                text_embedding_mask_t5=text_embedding_mask_t5,
                image_meta_size=image_meta_size,
                style=style,
                image_rotary_emb=image_rotary_emb,
                return_dict=return_dict,
            )

            # merge samples
            if i == 0:
                control_block_samples = block_samples
            else:
                # Element 0 of either result is the per-block sample sequence.
                merged = [
                    control_block_sample + block_sample
                    for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
                ]
                control_block_samples = (merged,)

        if control_block_samples is None:
            raise ValueError(
                "No controlnet was evaluated: `controlnet_cond`, `conditioning_scale` and the wrapped "
                "controlnets must each contain at least one element."
            )

        return control_block_samples
@register_to_config
def __init__(
    self,
    sample_size: int = 128,
    patch_size: int = 2,
    in_channels: int = 16,
    num_layers: int = 18,
    attention_head_dim: int = 64,
    num_attention_heads: int = 18,
    joint_attention_dim: int = 4096,
    caption_projection_dim: int = 1152,
    pooled_projection_dim: int = 2048,
    out_channels: int = 16,
    pos_embed_max_size: int = 96,
):
    """
    Build the SD3 controlnet.

    Args:
        sample_size (`int`): Height and width handed to both patch embedders.
        patch_size (`int`): Patch size handed to both patch embedders.
        in_channels (`int`): Input channels of both patch embedders.
        num_layers (`int`): Number of joint transformer blocks (and of controlnet projections).
        attention_head_dim (`int`): Per-head width; `inner_dim = num_attention_heads * attention_head_dim`.
        num_attention_heads (`int`): Number of attention heads per block.
        joint_attention_dim (`int`): Input width of the context embedder.
        caption_projection_dim (`int`): Output width of the context embedder.
        pooled_projection_dim (`int`): Width of the pooled projections fed to the timestep/text embedder.
        out_channels (`int`): Stored as `self.out_channels`; falls back to `in_channels` when `None`.
        pos_embed_max_size (`int`): Forwarded to the positional patch embedder only.
    """
    super().__init__()

    self.out_channels = in_channels if out_channels is None else out_channels
    self.inner_dim = num_attention_heads * attention_head_dim
    width = self.inner_dim

    # Both patch embedders share their geometry; only the positional-embedding options differ.
    patch_embed_kwargs = {
        "height": sample_size,
        "width": sample_size,
        "patch_size": patch_size,
        "in_channels": in_channels,
        "embed_dim": width,
    }

    self.pos_embed = PatchEmbed(pos_embed_max_size=pos_embed_max_size, **patch_embed_kwargs)
    self.time_text_embed = CombinedTimestepTextProjEmbeddings(
        embedding_dim=width, pooled_projection_dim=pooled_projection_dim
    )
    self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

    joint_blocks = []
    for _ in range(num_layers):
        joint_blocks.append(
            JointTransformerBlock(
                dim=width,
                num_attention_heads=num_attention_heads,
                attention_head_dim=self.config.attention_head_dim,
                context_pre_only=False,
            )
        )
    self.transformer_blocks = nn.ModuleList(joint_blocks)

    # controlnet_blocks: one `zero_module`-wrapped linear projection per transformer block
    self.controlnet_blocks = nn.ModuleList(
        [zero_module(nn.Linear(width, width)) for _ in self.transformer_blocks]
    )

    # Embeds the conditioning input; built without a positional-embedding type and wrapped in `zero_module`.
    self.pos_embed_input = zero_module(PatchEmbed(pos_embed_type=None, **patch_embed_kwargs))

    self.gradient_checkpointing = False
""" if dim not in [0, 1]: raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") # By default chunk size is 1 chunk_size = chunk_size or 1 def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) for child in module.children(): fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): fn_recursive_feed_forward(module, chunk_size, dim) @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. 
""" count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. For cross-attention modules, key and value projection matrices are fused. This API is 🧪 experimental. """ self.original_attn_processors = None for _, attn_processor in self.attn_processors.items(): if "Added" in str(attn_processor.__class__.__name__): raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") self.original_attn_processors = self.attn_processors for module in self.modules(): if isinstance(module, Attention): module.fuse_projections(fuse=True) self.set_attn_processor(FusedJointAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): """Disables the fused QKV projection if enabled. This API is 🧪 experimental. 
def forward(
    self,
    hidden_states: torch.FloatTensor,
    controlnet_cond: torch.Tensor,
    conditioning_scale: float = 1.0,
    encoder_hidden_states: torch.FloatTensor = None,
    pooled_projections: torch.FloatTensor = None,
    timestep: torch.LongTensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> Union[SD3ControlNetOutput, Tuple]:
    """
    The [`SD3ControlNetModel`] forward method.

    Args:
        hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
            Input `hidden_states`.
        controlnet_cond (`torch.Tensor`):
            The conditioning input. It is patch-embedded by `pos_embed_input` and added to the embedded
            `hidden_states`, so it presumably has the same image-like layout as `hidden_states` - TODO confirm.
        conditioning_scale (`float`, defaults to `1.0`):
            The scale factor for ControlNet outputs.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
            Embeddings projected from the embeddings of input conditions.
        timestep ( `torch.LongTensor`):
            Used to indicate denoising step.
        joint_attention_kwargs (`dict`, *optional*):
            Only its `"scale"` entry is read here (as the LoRA scale, default `1.0`); the dictionary is not
            passed on to the transformer blocks. The caller's dictionary is not modified.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`SD3ControlNetOutput`] instead of a plain tuple.

    Returns:
        If `return_dict` is True, an [`SD3ControlNetOutput`]; otherwise a 1-tuple. In both cases the payload is a
        list with one tensor per transformer block: the block's hidden states projected by the matching
        `controlnet_blocks` layer and multiplied by `conditioning_scale`.
    """
    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        # NOTE(review): "scale" was already popped from the copy above, so this condition can never
        # be true and the warning is unreachable as written. Left as-is to avoid new log output.
        if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

    hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
    temb = self.time_text_embed(timestep, pooled_projections)
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    # add
    hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)

    # Loop-invariant checkpointing setup.
    use_checkpointing = self.training and self.gradient_checkpointing
    ckpt_kwargs: Dict[str, Any] = {}
    if use_checkpointing:
        if is_torch_version(">=", "1.11.0"):
            ckpt_kwargs = {"use_reentrant": False}

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

    block_res_samples = ()
    for block in self.transformer_blocks:
        if use_checkpointing:
            # The block returns the pair `(encoder_hidden_states, hidden_states)`, exactly as in the
            # branch below. Both must be unpacked; assigning the pair to `hidden_states` would feed a
            # tuple into the next block and into the controlnet projections.
            encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                encoder_hidden_states,
                temb,
                **ckpt_kwargs,
            )
        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
            )

        block_res_samples = block_res_samples + (hidden_states,)

    # Project every block output with its own controlnet layer, then apply the conditioning scale.
    controlnet_block_res_samples = [
        controlnet_block(block_res_sample) * conditioning_scale
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks)
    ]

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (controlnet_block_res_samples,)

    return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
You must set multiple `SD3ControlNetModel` as a list. """ def __init__(self, controlnets): super().__init__() self.nets = nn.ModuleList(controlnets) def forward( self, hidden_states: torch.FloatTensor, controlnet_cond: List[torch.tensor], conditioning_scale: List[float], pooled_projections: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, timestep: torch.LongTensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[SD3ControlNetOutput, Tuple]: for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): block_samples = controlnet( hidden_states=hidden_states, timestep=timestep, encoder_hidden_states=encoder_hidden_states, pooled_projections=pooled_projections, controlnet_cond=image, conditioning_scale=scale, joint_attention_kwargs=joint_attention_kwargs, return_dict=return_dict, ) # merge samples if i == 0: control_block_samples = block_samples else: control_block_samples = [ control_block_sample + block_sample for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0]) ] control_block_samples = (tuple(control_block_samples),) return control_block_samples ================================================ FILE: src/diffusers/models/controlnet_sparsectrl.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
@dataclass
class SparseControlNetOutput(BaseOutput):
    """
    The output of [`SparseControlNetModel`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). The tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor
= (16, 32, 96, 256), ): super().__init__() self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) self.blocks = nn.ModuleList([]) for i in range(len(block_out_channels) - 1): channel_in = block_out_channels[i] channel_out = block_out_channels[i + 1] self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) self.conv_out = zero_module( nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) ) def forward(self, conditioning: torch.Tensor) -> torch.Tensor: embedding = self.conv_in(conditioning) embedding = F.silu(embedding) for block in self.blocks: embedding = block(embedding) embedding = F.silu(embedding) embedding = self.conv_out(embedding) return embedding class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): """ A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models](https://arxiv.org/abs/2311.16933). Args: in_channels (`int`, defaults to 4): The number of channels in the input sample. conditioning_channels (`int`, defaults to 4): The number of input channels in the controlnet conditional embedding module. If `concat_condition_embedding` is True, the value provided here is incremented by 1. flip_sin_to_cos (`bool`, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. 
layers_per_block (`int`, defaults to 2): The number of layers per block. downsample_padding (`int`, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, defaults to 1): The scale factor to use for the mid block. act_fn (`str`, defaults to "silu"): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If None, normalization and activation layers is skipped in post-processing. norm_eps (`float`, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int`, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer layers to use in each layer in the middle block. attention_head_dim (`int` or `Tuple[int]`, defaults to 8): The dimension of the attention heads. num_attention_heads (`int` or `Tuple[int]`, *optional*): The number of heads to use for multi-head attention. use_linear_projection (`bool`, defaults to `False`): upcast_attention (`bool`, defaults to `False`): resnet_time_scale_shift (`str`, defaults to `"default"`): Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`): The tuple of output channel for each block in the `conditioning_embedding` layer. 
global_pool_conditions (`bool`, defaults to `False`): TODO(Patrick) - unused parameter controlnet_conditioning_channel_order (`str`, defaults to `rgb`): motion_max_seq_length (`int`, defaults to `32`): The maximum sequence length to use in the motion module. motion_num_attention_heads (`int` or `Tuple[int]`, defaults to `8`): The number of heads to use in each attention layer of the motion module. concat_conditioning_mask (`bool`, defaults to `True`): use_simplified_condition_embedding (`bool`, defaults to `True`): """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, in_channels: int = 4, conditioning_channels: int = 4, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str, ...] = ( "CrossAttnDownBlockMotion", "CrossAttnDownBlockMotion", "CrossAttnDownBlockMotion", "DownBlockMotion", ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 768, transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None, temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, attention_head_dim: Union[int, Tuple[int, ...]] = 8, num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, use_linear_projection: bool = False, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), global_pool_conditions: bool = False, controlnet_conditioning_channel_order: str = "rgb", motion_max_seq_length: int = 32, motion_num_attention_heads: int = 8, concat_conditioning_mask: bool = True, use_simplified_condition_embedding: bool = True, ): super().__init__() 
self.use_simplified_condition_embedding = use_simplified_condition_embedding # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if isinstance(temporal_transformer_layers_per_block, int): temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types) # input conv_in_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) if concat_conditioning_mask: conditioning_channels = conditioning_channels + 1 self.concat_conditioning_mask = concat_conditioning_mask # control net conditioning embedding if use_simplified_condition_embedding: self.controlnet_cond_embedding = zero_module( nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) ) else: self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], block_out_channels=conditioning_embedding_out_channels, conditioning_channels=conditioning_channels, ) # time time_embed_dim = block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, ) self.down_blocks = nn.ModuleList([]) self.controlnet_down_blocks = nn.ModuleList([]) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(motion_num_attention_heads, int): motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types) # down 
output_channel = block_out_channels[0] controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 if down_block_type == "CrossAttnDownBlockMotion": down_block = CrossAttnDownBlockMotion( in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, dropout=0, num_layers=layers_per_block, transformer_layers_per_block=transformer_layers_per_block[i], resnet_eps=norm_eps, resnet_time_scale_shift=resnet_time_scale_shift, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, resnet_pre_norm=True, num_attention_heads=num_attention_heads[i], cross_attention_dim=cross_attention_dim[i], add_downsample=not is_final_block, dual_cross_attention=False, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, temporal_num_attention_heads=motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], temporal_double_self_attention=False, ) elif down_block_type == "DownBlockMotion": down_block = DownBlockMotion( in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, dropout=0, num_layers=layers_per_block, resnet_eps=norm_eps, resnet_time_scale_shift=resnet_time_scale_shift, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, resnet_pre_norm=True, add_downsample=not is_final_block, temporal_num_attention_heads=motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], temporal_double_self_attention=False, ) else: raise ValueError( "Invalid `block_type` encountered. 
@classmethod
def from_unet(
    cls,
    unet: UNet2DConditionModel,
    controlnet_conditioning_channel_order: str = "rgb",
    conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
    load_weights_from_unet: bool = True,
    conditioning_channels: int = 3,
) -> "SparseControlNetModel":
    r"""
    Instantiate a [`SparseControlNetModel`] from [`UNet2DConditionModel`].

    Parameters:
        unet (`UNet2DConditionModel`):
            The UNet model weights to copy to the [`SparseControlNetModel`]. All configuration options are also
            copied where applicable. The UNet and its config are left unmodified.
        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
            Forwarded to the constructor.
        conditioning_embedding_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(16, 32, 96, 256)`):
            Forwarded to the constructor.
        load_weights_from_unet (`bool`, defaults to `True`):
            Whether to load the state dicts of the UNet's `conv_in`, `time_proj`, `time_embedding`, `down_blocks` and
            `mid_block` into the new model. Loading is non-strict, so keys that do not match are ignored.
        conditioning_channels (`int`, defaults to 3):
            Forwarded to the constructor.

    Returns:
        `SparseControlNetModel`: The newly created model.

    Raises:
        ValueError: If a down block type of the UNet contains neither `"CrossAttn"` nor `"Down"`.
    """
    transformer_layers_per_block = (
        unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
    )

    # Translate the UNet block names into their motion counterparts in a *new* list. Rewriting
    # `unet.config.down_block_types` in place would corrupt the source UNet's config and fails outright
    # when the config stores the block types as a tuple.
    down_block_types = []
    for block_type in unet.config.down_block_types:
        if "CrossAttn" in block_type:
            down_block_types.append("CrossAttnDownBlockMotion")
        elif "Down" in block_type:
            down_block_types.append("DownBlockMotion")
        else:
            raise ValueError("Invalid `block_type` encountered. Must be a cross-attention or down block")

    controlnet = cls(
        in_channels=unet.config.in_channels,
        conditioning_channels=conditioning_channels,
        flip_sin_to_cos=unet.config.flip_sin_to_cos,
        freq_shift=unet.config.freq_shift,
        down_block_types=down_block_types,
        only_cross_attention=unet.config.only_cross_attention,
        block_out_channels=unet.config.block_out_channels,
        layers_per_block=unet.config.layers_per_block,
        downsample_padding=unet.config.downsample_padding,
        mid_block_scale_factor=unet.config.mid_block_scale_factor,
        act_fn=unet.config.act_fn,
        norm_num_groups=unet.config.norm_num_groups,
        norm_eps=unet.config.norm_eps,
        cross_attention_dim=unet.config.cross_attention_dim,
        transformer_layers_per_block=transformer_layers_per_block,
        attention_head_dim=unet.config.attention_head_dim,
        num_attention_heads=unet.config.num_attention_heads,
        use_linear_projection=unet.config.use_linear_projection,
        upcast_attention=unet.config.upcast_attention,
        resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
        conditioning_embedding_out_channels=conditioning_embedding_out_channels,
        controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
    )

    if load_weights_from_unet:
        # strict=False: the two models do not share every key, so only matching keys are loaded.
        controlnet.conv_in.load_state_dict(unet.conv_in.state_dict(), strict=False)
        controlnet.time_proj.load_state_dict(unet.time_proj.state_dict(), strict=False)
        controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict(), strict=False)
        controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
        controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)

    return controlnet
def set_default_attn_processor(self):
    """
    Replace every attention processor of the model with the matching default implementation.

    `AttnAddedKVProcessor` is used when all current processors belong to `ADDED_KV_ATTENTION_PROCESSORS`, and
    `AttnProcessor` when all of them belong to `CROSS_ATTENTION_PROCESSORS`.

    Raises:
        ValueError: If the current processors fit neither of the two groups.
    """
    current_processors = list(self.attn_processors.values())
    processor_classes = [proc.__class__ for proc in current_processors]

    # The added-KV group is tested first, exactly as before, so a model without processors still selects it.
    if all(klass in ADDED_KV_ATTENTION_PROCESSORS for klass in processor_classes):
        default_processor = AttnAddedKVProcessor()
    elif all(klass in CROSS_ATTENTION_PROCESSORS for klass in processor_classes):
        default_processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(current_processors))}"
        )

    self.set_attn_processor(default_processor)
In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. 
def forward(
    self,
    sample: torch.Tensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    controlnet_cond: torch.Tensor,
    conditioning_scale: float = 1.0,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    conditioning_mask: Optional[torch.Tensor] = None,
    guess_mode: bool = False,
    return_dict: bool = True,
) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
    """
    The [`SparseControlNetModel`] forward method.

    Args:
        sample (`torch.Tensor`):
            The noisy input tensor. It is unpacked as five dimensions `(batch, channels, frames, height, width)`.
            Only its shape, dtype and device are used: the values are replaced by zeros before any computation.
        timestep (`Union[torch.Tensor, float, int]`):
            The current timestep. A Python number is turned into a one-element tensor; the result is expanded to
            the batch size.
        encoder_hidden_states (`torch.Tensor`):
            The encoder hidden states. They are repeated once per frame along dimension 0.
        controlnet_cond (`torch.Tensor`):
            The conditioning input. It is unpacked as five dimensions `(batch, channels, frames, height, width)`.
            Its channel dimension is flipped when the configured channel order is `"bgr"`.
        conditioning_scale (`float`, defaults to `1.0`):
            The scale factor for ControlNet outputs.
        timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
            Additional conditional embeddings for timestep, passed to `self.time_embedding` together with the
            projected timestep.
        attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            A mask in which `1` keeps and `0` discards a token. It is converted into an additive bias of `-10000.0`
            for discarded tokens and gets an extra dimension inserted at index 1.
        cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
            A kwargs dictionary that is passed along to the blocks that have cross attention.
        conditioning_mask (`torch.Tensor`, *optional*, defaults to `None`):
            Concatenated to `controlnet_cond` along dimension 1 when the model was built with
            `concat_conditioning_mask=True`. In that case a tensor has to be provided, because the value is handed
            to `torch.cat` without a `None` check.
        guess_mode (`bool`, defaults to `False`):
            When `True` and `global_pool_conditions` is disabled in the config, the outputs are multiplied by
            log-spaced factors from 0.1 to 1.0 (times `conditioning_scale`) instead of by `conditioning_scale` alone.
        return_dict (`bool`, defaults to `True`):
            Whether or not to return a [`SparseControlNetOutput`] instead of a plain tuple.

    Returns:
        [`SparseControlNetOutput`] **or** `tuple`:
            If `return_dict` is `True`, a [`SparseControlNetOutput`] is returned, otherwise a tuple
            `(down_block_res_samples, mid_block_res_sample)`. In both cases `down_block_res_samples` is a list of
            tensors.

    Raises:
        ValueError: If the configured `controlnet_conditioning_channel_order` is neither `"rgb"` nor `"bgr"`.
    """
    sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape
    # The incoming latent values are discarded here; from this point on `sample` only carries shape/dtype/device.
    sample = torch.zeros_like(sample)

    # check channel order
    channel_order = self.config.controlnet_conditioning_channel_order

    if channel_order == "rgb":
        # in rgb order by default
        ...
    elif channel_order == "bgr":
        controlnet_cond = torch.flip(controlnet_cond, dims=[1])
    else:
        raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

    # prepare attention_mask
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=sample.dtype)

    emb = self.time_embedding(t_emb, timestep_cond)
    # One embedding / text-conditioning row per frame, matching the (batch * frames) layout used below.
    emb = emb.repeat_interleave(sample_num_frames, dim=0)
    encoder_hidden_states = encoder_hidden_states.repeat_interleave(sample_num_frames, dim=0)

    # 2. pre-process
    batch_size, channels, num_frames, height, width = sample.shape

    # Fold the frame dimension into the batch dimension so the 2D convolution sees one image per frame.
    sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
    sample = self.conv_in(sample)

    batch_frames, channels, height, width = sample.shape
    sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width)

    if self.concat_conditioning_mask:
        controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)

    # Same fold / embed / unfold sequence for the conditioning input.
    batch_size, channels, num_frames, height, width = controlnet_cond.shape
    controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape(
        batch_size * num_frames, channels, height, width
    )
    controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
    batch_frames, channels, height, width = controlnet_cond.shape
    controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width)

    sample = sample + controlnet_cond

    # NOTE: `num_frames` is rebound here and is the value passed to the down blocks below.
    batch_size, num_frames, channels, height, width = sample.shape
    sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width)

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)

        down_block_res_samples += res_samples

    # 4. mid
    if self.mid_block is not None:
        if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        else:
            sample = self.mid_block(sample, emb)

    # 5. Control net blocks
    controlnet_down_block_res_samples = ()

    # Each residual is passed through its own entry of `self.controlnet_down_blocks`, paired by position.
    for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
        down_block_res_sample = controlnet_block(down_block_res_sample)
        controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

    down_block_res_samples = controlnet_down_block_res_samples
    mid_block_res_sample = self.controlnet_mid_block(sample)

    # 6. scaling
    if guess_mode and not self.config.global_pool_conditions:
        scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
        scales = scales * conditioning_scale
        down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
        mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
    else:
        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
        mid_block_res_sample = mid_block_res_sample * conditioning_scale

    if self.config.global_pool_conditions:
        # Average over dimensions 2 and 3 while keeping them as size-1 dimensions.
        down_block_res_samples = [
            torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
        ]
        mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

    if not return_dict:
        return (down_block_res_samples, mid_block_res_sample)

    return SparseControlNetOutput(
        down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
    )
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from math import gcd from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import Tensor, nn from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, is_torch_version, logging from ..utils.torch_utils import apply_freeu from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, FusedAttnProcessor2_0, ) from .controlnet import ControlNetConditioningEmbedding from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unets.unet_2d_blocks import ( CrossAttnDownBlock2D, CrossAttnUpBlock2D, Downsample2D, ResnetBlock2D, Transformer2DModel, UNetMidBlock2DCrossAttn, Upsample2D, ) from .unets.unet_2d_condition import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class ControlNetXSOutput(BaseOutput): """ The output of [`UNetControlNetXSModel`]. Args: sample (`Tensor` of shape `(batch_size, num_channels, height, width)`): The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model output, but is already the final output. 
class DownBlockControlNetXSAdapter(nn.Module):
    """Control-side pieces of one down block.

    Combined with the matching pieces of the base model they make up a `ControlNetXSCrossAttnDownBlock2D`.
    """

    def __init__(
        self,
        resnets: nn.ModuleList,
        base_to_ctrl: nn.ModuleList,
        ctrl_to_base: nn.ModuleList,
        attentions: Optional[nn.ModuleList] = None,
        downsampler: Optional[nn.Conv2d] = None,
    ):
        super().__init__()

        # Assignment order is the submodule registration order; keep it stable.
        self.resnets = resnets
        self.base_to_ctrl = base_to_ctrl
        self.ctrl_to_base = ctrl_to_base

        # Both of these may legitimately stay `None`.
        self.attentions = attentions
        # NOTE: the constructor argument is singular, the attribute is plural.
        self.downsamplers = downsampler


class MidBlockControlNetXSAdapter(nn.Module):
    """Control-side pieces of the mid block.

    Combined with the matching pieces of the base model they make up a `ControlNetXSCrossAttnMidBlock2D`.
    """

    def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl: nn.ModuleList, ctrl_to_base: nn.ModuleList):
        super().__init__()

        # Assignment order is the submodule registration order; keep it stable.
        self.midblock = midblock
        self.base_to_ctrl = base_to_ctrl
        self.ctrl_to_base = ctrl_to_base


class UpBlockControlNetXSAdapter(nn.Module):
    """Control-side pieces of one up block.

    Combined with the matching pieces of the base model they make up a `ControlNetXSCrossAttnUpBlock2D`.
    """

    def __init__(self, ctrl_to_base: nn.ModuleList):
        super().__init__()

        # The up path only holds control -> base connections.
        self.ctrl_to_base = ctrl_to_base
i == 0 else base_out_channels ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels # Before the resnet/attention application, information is concatted from base to control. # Concat doesn't require change in number of channels base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) resnets.append( ResnetBlock2D( in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl out_channels=ctrl_out_channels, temb_channels=temb_channels, groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups), groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), eps=1e-5, ) ) if has_crossattn: attentions.append( Transformer2DModel( num_attention_heads, ctrl_out_channels // num_attention_heads, in_channels=ctrl_out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), ) ) # After the resnet/attention application, information is added from control to base # Addition requires change in number of channels ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) if add_downsample: # Before the downsampler application, information is concatted from base to control # Concat doesn't require change in number of channels base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) downsamplers = Downsample2D( ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op" ) # After the downsampler application, information is added from control to base # Addition requires change in number of channels ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) else: downsamplers = None down_block_components = DownBlockControlNetXSAdapter( resnets=nn.ModuleList(resnets), 
def get_mid_block_adapter(
    base_channels: int,
    ctrl_channels: int,
    temb_channels: Optional[int] = None,
    max_norm_num_groups: Optional[int] = 32,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = 1,
    cross_attention_dim: Optional[int] = 1024,
    upcast_attention: bool = False,
    use_linear_projection: bool = True,
):
    """Build the control-side mid block plus its two connections to the base model.

    Returns:
        A `MidBlockControlNetXSAdapter` holding `base_to_ctrl`, `midblock` and `ctrl_to_base`.
    """
    # The mid block consumes control features with the base features concatenated on.
    concat_channels = ctrl_channels + base_channels

    # The number of norm groups must divide both in_channels and out_channels of the mid block.
    norm_groups = find_largest_factor(gcd(ctrl_channels, concat_channels), max_norm_num_groups)

    # Modules are created in the order base_to_ctrl -> midblock -> ctrl_to_base.
    parts = {}

    # base -> ctrl happens before the mid block; concatenation needs no channel change.
    parts["base_to_ctrl"] = make_zero_conv(base_channels, base_channels)

    parts["midblock"] = UNetMidBlock2DCrossAttn(
        transformer_layers_per_block=transformer_layers_per_block,
        in_channels=concat_channels,
        out_channels=ctrl_channels,
        temb_channels=temb_channels,
        resnet_groups=norm_groups,
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=num_attention_heads,
        use_linear_projection=use_linear_projection,
        upcast_attention=upcast_attention,
    )

    # ctrl -> base happens after the mid block; addition requires matching the base channel count.
    parts["ctrl_to_base"] = make_zero_conv(ctrl_channels, base_channels)

    return MidBlockControlNetXSAdapter(**parts)
class ControlNetXSAdapter(ModelMixin, ConfigMixin):
    r"""
    A `ControlNetXSAdapter` model. To use it, pass it into a `UNetControlNetXSModel` (together with a
    `UNet2DConditionModel` base model).

    This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
    methods implemented for all models (such as downloading or saving).

    Like `UNetControlNetXSModel`, `ControlNetXSAdapter` is compatible with StableDiffusion and StableDiffusion-XL. Its
    default parameters are compatible with StableDiffusion.

    Parameters:
        conditioning_channels (`int`, defaults to 3):
            Number of channels of conditioning input (e.g. an image)
        conditioning_channel_order (`str`, defaults to `"rgb"`):
            The channel order of conditional image. Must be `"rgb"` or `"bgr"`.
        conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
            The tuple of output channels for each block in the `controlnet_cond_embedding` layer.
        time_embedding_mix (`float`, defaults to 1.0):
            If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time
            embedding is used. Otherwise, both are combined.
        learn_time_embedding (`bool`, defaults to `False`):
            Whether a time embedding should be learned. If yes, `UNetControlNetXSModel` will combine the time
            embeddings of the base model and the control adapter. If no, `UNetControlNetXSModel` will use the base
            model's time embedding.
        num_attention_heads (`int` or `tuple[int]`, defaults to 4):
            The number of attention heads. A single value is repeated for every down block.
        block_out_channels (`tuple[int]`, defaults to `(4, 8, 16, 16)`):
            The tuple of output channels for each block.
        base_block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block in the base unet.
        cross_attention_dim (`int`, defaults to 1024):
            The dimension of the cross attention features.
        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        sample_size (`int`, defaults to 96):
            Height and width of input/output sample.
        transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        upcast_attention (`bool`, defaults to `True`):
            Whether the attention computation should always be upcasted.
        max_norm_num_groups (`int`, defaults to 32):
            Maximum number of groups in group normal. The actual number will be the largest divisor of the respective
            channels, that is <= max_norm_num_groups.
        use_linear_projection (`bool`, defaults to `True`):
            Forwarded unchanged to the down-block and mid-block adapter factories.
    """

    @register_to_config
    def __init__(
        self,
        conditioning_channels: int = 3,
        conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
        time_embedding_mix: float = 1.0,
        learn_time_embedding: bool = False,
        num_attention_heads: Union[int, Tuple[int]] = 4,
        block_out_channels: Tuple[int] = (4, 8, 16, 16),
        base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        cross_attention_dim: int = 1024,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        sample_size: Optional[int] = 96,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        upcast_attention: bool = True,
        max_norm_num_groups: int = 32,
        use_linear_projection: bool = True,
    ):
        super().__init__()

        # The time embedding dimensions follow the base model so both embeddings can be mixed.
        time_embedding_input_dim = base_block_out_channels[0]
        time_embedding_dim = base_block_out_channels[0] * 4

        # Check inputs
        if conditioning_channel_order not in ["rgb", "bgr"]:
            raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}")

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        # Scalars are broadcast to one entry per down block.
        if not isinstance(transformer_layers_per_block, (list, tuple)):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
        if not isinstance(cross_attention_dim, (list, tuple)):
            cross_attention_dim = [cross_attention_dim] * len(down_block_types)
        # see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why `ControlNetXSAdapter` takes `num_attention_heads` instead of `attention_head_dim`
        if not isinstance(num_attention_heads, (list, tuple)):
            num_attention_heads = [num_attention_heads] * len(down_block_types)

        if len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        # Create conditioning hint embedding
        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=conditioning_channels,
        )

        # time
        if learn_time_embedding:
            self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim)
        else:
            self.time_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.up_connections = nn.ModuleList([])

        # input
        self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1)
        self.control_to_base_for_conv_in = make_zero_conv(block_out_channels[0], base_block_out_channels[0])

        # down
        base_out_channels = base_block_out_channels[0]
        ctrl_out_channels = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            base_in_channels = base_out_channels
            base_out_channels = base_block_out_channels[i]
            ctrl_in_channels = ctrl_out_channels
            ctrl_out_channels = block_out_channels[i]
            has_crossattn = "CrossAttn" in down_block_type
            is_final_block = i == len(down_block_types) - 1

            self.down_blocks.append(
                get_down_block_adapter(
                    base_in_channels=base_in_channels,
                    base_out_channels=base_out_channels,
                    ctrl_in_channels=ctrl_in_channels,
                    ctrl_out_channels=ctrl_out_channels,
                    temb_channels=time_embedding_dim,
                    max_norm_num_groups=max_norm_num_groups,
                    has_crossattn=has_crossattn,
                    transformer_layers_per_block=transformer_layers_per_block[i],
                    num_attention_heads=num_attention_heads[i],
                    cross_attention_dim=cross_attention_dim[i],
                    add_downsample=not is_final_block,
                    upcast_attention=upcast_attention,
                    use_linear_projection=use_linear_projection,
                )
            )

        # mid
        self.mid_block = get_mid_block_adapter(
            base_channels=base_block_out_channels[-1],
            ctrl_channels=block_out_channels[-1],
            temb_channels=time_embedding_dim,
            # Honour the configured cap here as well, as the down blocks above do; previously the
            # factory default was used silently regardless of the configured value.
            max_norm_num_groups=max_norm_num_groups,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            num_attention_heads=num_attention_heads[-1],
            cross_attention_dim=cross_attention_dim[-1],
            upcast_attention=upcast_attention,
            use_linear_projection=use_linear_projection,
        )

        # up
        # The skip connection channels are the output of the conv_in and of all the down subblocks
        ctrl_skip_channels = [block_out_channels[0]]
        for i, out_channels in enumerate(block_out_channels):
            number_of_subblocks = (
                3 if i < len(block_out_channels) - 1 else 2
            )  # every block has 3 subblocks, except last one, which has 2 as it has no downsampler
            ctrl_skip_channels.extend([out_channels] * number_of_subblocks)

        reversed_base_block_out_channels = list(reversed(base_block_out_channels))

        base_out_channels = reversed_base_block_out_channels[0]
        for i in range(len(down_block_types)):
            prev_base_output_channel = base_out_channels
            base_out_channels = reversed_base_block_out_channels[i]
            # Skip channels are consumed from the end, three per up block.
            ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)]

            self.up_connections.append(
                get_up_block_adapter(
                    out_channels=base_out_channels,
                    prev_output_channel=prev_base_output_channel,
                    ctrl_skip_channels=ctrl_skip_channels_,
                )
            )

    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        size_ratio: Optional[float] = None,
        block_out_channels: Optional[List[int]] = None,
        num_attention_heads: Optional[List[int]] = None,
        learn_time_embedding: bool = False,
        time_embedding_mix: float = 1.0,
        conditioning_channels: int = 3,
        conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        r"""
        Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].

        Parameters:
            unet (`UNet2DConditionModel`):
                The UNet model we want to control. The dimensions of the ControlNetXSAdapter will be adapted to it.
            size_ratio (float, *optional*, defaults to `None`):
                When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this
                or `block_out_channels` must be given.
            block_out_channels (`List[int]`, *optional*, defaults to `None`):
                Down blocks output channels in control model. Either this or `size_ratio` must be given.
            num_attention_heads (`List[int]`, *optional*, defaults to `None`):
                The dimension of the attention heads. The naming seems a bit confusing and it is, see
                https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. When `None`, the
                base model's `attention_head_dim` config value is used.
            learn_time_embedding (`bool`, defaults to `False`):
                Whether the `ControlNetXSAdapter` should learn a time embedding.
            time_embedding_mix (`float`, defaults to 1.0):
                If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time
                embedding is used. Otherwise, both are combined.
            conditioning_channels (`int`, defaults to 3):
                Number of channels of conditioning input (e.g. an image)
            conditioning_channel_order (`str`, defaults to `"rgb"`):
                The channel order of conditional image. Must be `"rgb"` or `"bgr"`.
            conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
                The tuple of output channel for each block in the `controlnet_cond_embedding` layer.

        Returns:
            A new `ControlNetXSAdapter`, moved to the dtype of `unet`.

        Raises:
            ValueError: If both or neither of `block_out_channels` and `size_ratio` are given.
        """

        # Check input
        fixed_size = block_out_channels is not None
        relative_size = size_ratio is not None
        if not (fixed_size ^ relative_size):
            raise ValueError(
                "Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)."
            )

        # Create model
        # Test against `None` explicitly: an empty list must reach the constructor's own length
        # validation instead of falling through to `int(b * None)`.
        if block_out_channels is None:
            block_out_channels = [int(b * size_ratio) for b in unet.config.block_out_channels]
        if num_attention_heads is None:
            # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
            num_attention_heads = unet.config.attention_head_dim

        model = cls(
            conditioning_channels=conditioning_channels,
            conditioning_channel_order=conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
            time_embedding_mix=time_embedding_mix,
            learn_time_embedding=learn_time_embedding,
            num_attention_heads=num_attention_heads,
            block_out_channels=block_out_channels,
            base_block_out_channels=unet.config.block_out_channels,
            cross_attention_dim=unet.config.cross_attention_dim,
            down_block_types=unet.config.down_block_types,
            sample_size=unet.config.sample_size,
            transformer_layers_per_block=unet.config.transformer_layers_per_block,
            upcast_attention=unet.config.upcast_attention,
            max_norm_num_groups=unet.config.norm_num_groups,
            use_linear_projection=unet.config.use_linear_projection,
        )

        # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
        model.to(unet.dtype)

        return model

    def forward(self, *args, **kwargs):
        """Always raises: the adapter holds only the control-side modules and has no forward pass of its own."""
        raise ValueError(
            "A ControlNetXSAdapter cannot be run by itself. Use it together with a UNet2DConditionModel to instantiate a UNetControlNetXSModel."
        )
`UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. It's default parameters are compatible with StableDiffusion. It's parameters are either passed to the underlying `UNet2DConditionModel` or used exactly like in `ControlNetXSAdapter` . See their documentation for details. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, # unet configs sample_size: Optional[int] = 96, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), block_out_channels: Tuple[int] = (320, 640, 1280, 1280), norm_num_groups: Optional[int] = 32, cross_attention_dim: Union[int, Tuple[int]] = 1024, transformer_layers_per_block: Union[int, Tuple[int]] = 1, num_attention_heads: Union[int, Tuple[int]] = 8, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, upcast_attention: bool = True, use_linear_projection: bool = True, time_cond_proj_dim: Optional[int] = None, projection_class_embeddings_input_dim: Optional[int] = None, # additional controlnet configs time_embedding_mix: float = 1.0, ctrl_conditioning_channels: int = 3, ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), ctrl_conditioning_channel_order: str = "rgb", ctrl_learn_time_embedding: bool = False, ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16), ctrl_num_attention_heads: Union[int, Tuple[int]] = 4, ctrl_max_norm_num_groups: int = 32, ): super().__init__() if time_embedding_mix < 0 or time_embedding_mix > 1: raise ValueError("`time_embedding_mix` needs to be between 0 and 1.") if time_embedding_mix < 1 and not ctrl_learn_time_embedding: raise ValueError("To use `time_embedding_mix` < 1, `ctrl_learn_time_embedding` must be `True`") if addition_embed_type is not None and addition_embed_type != "text_time": raise ValueError( "As 
`UNetControlNetXSModel` currently only supports StableDiffusion and StableDiffusion-XL, `addition_embed_type` must be `None` or `'text_time'`." ) if not isinstance(transformer_layers_per_block, (list, tuple)): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if not isinstance(cross_attention_dim, (list, tuple)): cross_attention_dim = [cross_attention_dim] * len(down_block_types) if not isinstance(num_attention_heads, (list, tuple)): num_attention_heads = [num_attention_heads] * len(down_block_types) if not isinstance(ctrl_num_attention_heads, (list, tuple)): ctrl_num_attention_heads = [ctrl_num_attention_heads] * len(down_block_types) base_num_attention_heads = num_attention_heads self.in_channels = 4 # # Input self.base_conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=ctrl_block_out_channels[0], block_out_channels=ctrl_conditioning_embedding_out_channels, conditioning_channels=ctrl_conditioning_channels, ) self.ctrl_conv_in = nn.Conv2d(4, ctrl_block_out_channels[0], kernel_size=3, padding=1) self.control_to_base_for_conv_in = make_zero_conv(ctrl_block_out_channels[0], block_out_channels[0]) # # Time time_embed_input_dim = block_out_channels[0] time_embed_dim = block_out_channels[0] * 4 self.base_time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0) self.base_time_embedding = TimestepEmbedding( time_embed_input_dim, time_embed_dim, cond_proj_dim=time_cond_proj_dim, ) if ctrl_learn_time_embedding: self.ctrl_time_embedding = TimestepEmbedding( in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim ) else: self.ctrl_time_embedding = None if addition_embed_type is None: self.base_add_time_proj = None self.base_add_embedding = None else: self.base_add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) self.base_add_embedding = 
TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) # # Create down blocks down_blocks = [] base_out_channels = block_out_channels[0] ctrl_out_channels = ctrl_block_out_channels[0] for i, down_block_type in enumerate(down_block_types): base_in_channels = base_out_channels base_out_channels = block_out_channels[i] ctrl_in_channels = ctrl_out_channels ctrl_out_channels = ctrl_block_out_channels[i] has_crossattn = "CrossAttn" in down_block_type is_final_block = i == len(down_block_types) - 1 down_blocks.append( ControlNetXSCrossAttnDownBlock2D( base_in_channels=base_in_channels, base_out_channels=base_out_channels, ctrl_in_channels=ctrl_in_channels, ctrl_out_channels=ctrl_out_channels, temb_channels=time_embed_dim, norm_num_groups=norm_num_groups, ctrl_max_norm_num_groups=ctrl_max_norm_num_groups, has_crossattn=has_crossattn, transformer_layers_per_block=transformer_layers_per_block[i], base_num_attention_heads=base_num_attention_heads[i], ctrl_num_attention_heads=ctrl_num_attention_heads[i], cross_attention_dim=cross_attention_dim[i], add_downsample=not is_final_block, upcast_attention=upcast_attention, use_linear_projection=use_linear_projection, ) ) # # Create mid block self.mid_block = ControlNetXSCrossAttnMidBlock2D( base_channels=block_out_channels[-1], ctrl_channels=ctrl_block_out_channels[-1], temb_channels=time_embed_dim, norm_num_groups=norm_num_groups, ctrl_max_norm_num_groups=ctrl_max_norm_num_groups, transformer_layers_per_block=transformer_layers_per_block[-1], base_num_attention_heads=base_num_attention_heads[-1], ctrl_num_attention_heads=ctrl_num_attention_heads[-1], cross_attention_dim=cross_attention_dim[-1], upcast_attention=upcast_attention, use_linear_projection=use_linear_projection, ) # # Create up blocks up_blocks = [] rev_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) rev_num_attention_heads = list(reversed(base_num_attention_heads)) rev_cross_attention_dim = list(reversed(cross_attention_dim)) 
@classmethod
def from_unet(
    cls,
    unet: UNet2DConditionModel,
    controlnet: Optional[ControlNetXSAdapter] = None,
    size_ratio: Optional[float] = None,
    ctrl_block_out_channels: Optional[List[int]] = None,
    time_embedding_mix: Optional[float] = None,
    ctrl_optional_kwargs: Optional[Dict] = None,
):
    r"""
    Instantiate a [`UNetControlNetXSModel`] from a [`UNet2DConditionModel`] and an optional [`ControlNetXSAdapter`]
    .

    Parameters:
        unet (`UNet2DConditionModel`):
            The UNet model we want to control.
        controlnet (`ControlNetXSAdapter`):
            The ControlNet-XS adapter with which the UNet will be fused. If none is given, a new ControlNet-XS
            adapter will be created.
        size_ratio (float, *optional*, defaults to `None`):
            Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
        ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
            Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
            where this parameter is called `block_out_channels`.
        time_embedding_mix (`float`, *optional*, defaults to None):
            Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
            A `time_embedding_mix` entry in `ctrl_optional_kwargs` takes precedence over this argument.
        ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
            Extra keyword arguments forwarded to [`ControlNetXSAdapter.from_unet`] when no controlnet was given.

    Returns:
        A `UNetControlNetXSModel` carrying the weights of `unet` and of the adapter, in the dtype of `unet`.

    Raises:
        ValueError: If a controlnet is passed together with any of the adapter-construction arguments.
    """
    if controlnet is None:
        # `ctrl_optional_kwargs` defaults to None, which cannot be `**`-unpacked; work on a copy
        # so the caller's dict is never mutated.
        adapter_kwargs = dict(ctrl_optional_kwargs or {})
        if time_embedding_mix is not None:
            # Forward the explicit argument (it used to be dropped). `setdefault` lets a value
            # already present in `ctrl_optional_kwargs` win, as it did before.
            adapter_kwargs.setdefault("time_embedding_mix", time_embedding_mix)

        controlnet = ControlNetXSAdapter.from_unet(unet, size_ratio, ctrl_block_out_channels, **adapter_kwargs)
    else:
        if any(
            o is not None for o in (size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs)
        ):
            raise ValueError(
                "When a controlnet is passed, none of these parameters should be passed: size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs."
            )

    # # get params
    unet_param_names = [
        "sample_size",
        "down_block_types",
        "up_block_types",
        "block_out_channels",
        "norm_num_groups",
        "cross_attention_dim",
        "transformer_layers_per_block",
        "addition_embed_type",
        "addition_time_embed_dim",
        "upcast_attention",
        "use_linear_projection",
        "time_cond_proj_dim",
        "projection_class_embeddings_input_dim",
    ]
    params_for_unet = {k: v for k, v in unet.config.items() if k in unet_param_names}
    # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
    params_for_unet["num_attention_heads"] = unet.config.attention_head_dim

    controlnet_param_names = [
        "conditioning_channels",
        "conditioning_embedding_out_channels",
        "conditioning_channel_order",
        "learn_time_embedding",
        "block_out_channels",
        "num_attention_heads",
        "max_norm_num_groups",
    ]
    # Adapter settings are stored on the fused model under a `ctrl_` prefix ...
    params_for_controlnet = {"ctrl_" + k: v for k, v in controlnet.config.items() if k in controlnet_param_names}
    # ... except `time_embedding_mix`, which keeps its plain name.
    params_for_controlnet["time_embedding_mix"] = controlnet.config.time_embedding_mix

    # # create model
    model = cls.from_config({**params_for_unet, **params_for_controlnet})

    # # load weights
    # from unet
    modules_from_unet = [
        "time_embedding",
        "conv_in",
        "conv_norm_out",
        "conv_out",
    ]
    for m in modules_from_unet:
        getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())

    optional_modules_from_unet = [
        "add_time_proj",
        "add_embedding",
    ]
    for m in optional_modules_from_unet:
        if hasattr(unet, m) and getattr(unet, m) is not None:
            getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())

    # from controlnet
    model.controlnet_cond_embedding.load_state_dict(controlnet.controlnet_cond_embedding.state_dict())
    model.ctrl_conv_in.load_state_dict(controlnet.conv_in.state_dict())
    if controlnet.time_embedding is not None:
        model.ctrl_time_embedding.load_state_dict(controlnet.time_embedding.state_dict())
    model.control_to_base_for_conv_in.load_state_dict(controlnet.control_to_base_for_conv_in.state_dict())

    # from both: the fused blocks are rebuilt from the existing base and adapter modules
    model.down_blocks = nn.ModuleList(
        ControlNetXSCrossAttnDownBlock2D.from_modules(b, c) for b, c in zip(unet.down_blocks, controlnet.down_blocks)
    )
    model.mid_block = ControlNetXSCrossAttnMidBlock2D.from_modules(unet.mid_block, controlnet.mid_block)
    model.up_blocks = nn.ModuleList(
        ControlNetXSCrossAttnUpBlock2D.from_modules(b, c)
        for b, c in zip(unet.up_blocks, controlnet.up_connections)
    )

    # ensure that the UNetControlNetXSModel is the same dtype as the UNet2DConditionModel
    model.to(unet.dtype)

    return model

def freeze_unet_params(self) -> None:
    """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
    tuning."""
    # Start by making every parameter trainable ...
    for param in self.parameters():
        param.requires_grad = True

    # ... then freeze the parts that belong to the base UNet.
    base_parts = [
        "base_time_proj",
        "base_time_embedding",
        "base_add_time_proj",
        "base_add_embedding",
        "base_conv_in",
        "base_conv_norm_out",
        "base_conv_act",
        "base_conv_out",
    ]
    # Optional parts may be `None`; skip those.
    base_parts = [getattr(self, part) for part in base_parts if getattr(self, part) is not None]
    for part in base_parts:
        for param in part.parameters():
            param.requires_grad = False

    # The fused blocks freeze their own base-model share.
    for d in self.down_blocks:
        d.freeze_base_params()
    self.mid_block.freeze_base_params()
    for u in self.up_blocks:
        u.freeze_base_params()

def _set_gradient_checkpointing(self, module, value=False):
    """Set `gradient_checkpointing` to `value` on `module`, if the module has that attribute."""
    if hasattr(module, "gradient_checkpointing"):
        module.gradient_checkpointing = value
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

    Note that a dict argument is consumed: every entry that gets applied is popped from it.
    """
    expected = len(self.attn_processors)
    per_layer = isinstance(processor, dict)

    if per_layer and len(processor) != expected:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {expected}. Please make sure to pass {expected} processor classes."
        )

    def _assign(path: str, layer: torch.nn.Module):
        # Any module exposing `set_processor` receives one; the dict form is keyed by "<path>.processor".
        if hasattr(layer, "set_processor"):
            layer.set_processor(processor.pop(f"{path}.processor") if per_layer else processor)

        # Descend depth-first, extending the dotted path with each child name.
        for child_name, child in layer.named_children():
            _assign(f"{path}.{child_name}", child)

    for top_name, top_module in self.named_children():
        _assign(top_name, top_module)

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
    """
    Disables custom attention processors and sets the default attention implementation.
    """

    def _all_from(pool):
        # True when every current processor's class belongs to `pool`.
        return all(proc.__class__ in pool for proc in self.attn_processors.values())

    if _all_from(ADDED_KV_ATTENTION_PROCESSORS):
        default = AttnAddedKVProcessor()
    elif _all_from(CROSS_ATTENTION_PROCESSORS):
        default = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(default)
def disable_freeu(self):
    """Disables the FreeU mechanism.

    Resets the FreeU scaling attributes (`s1`, `s2`, `b1`, `b2`) to `None` on every up block that carries them.
    Blocks on which an attribute was never set are left untouched.
    """
    freeu_keys = ("s1", "s2", "b1", "b2")
    for upsample_block in self.up_blocks:
        for key in freeu_keys:
            # `hasattr` already covers the "attribute holds a non-None value" case, so a single check suffices.
            if hasattr(upsample_block, key):
                setattr(upsample_block, key, None)
""" if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) def forward( self, sample: Tensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, controlnet_cond: Optional[torch.Tensor] = None, conditioning_scale: Optional[float] = 1.0, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, return_dict: bool = True, apply_control: bool = True, ) -> Union[ControlNetXSOutput, Tuple]: """ The [`ControlNetXSModel`] forward method. Args: sample (`Tensor`): The noisy input tensor. timestep (`Union[torch.Tensor, float, int]`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.Tensor`): The encoder hidden states. controlnet_cond (`Tensor`): The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. conditioning_scale (`float`, defaults to `1.0`): How much the control model affects the base model outputs. class_labels (`torch.Tensor`, *optional*, defaults to `None`): Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep embeddings. attention_mask (`torch.Tensor`, *optional*, defaults to `None`): An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. 
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttnProcessor`. added_cond_kwargs (`dict`): Additional conditions for the Stable Diffusion XL UNet. return_dict (`bool`, defaults to `True`): Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. apply_control (`bool`, defaults to `True`): If `False`, the input is run only through the base model. Returns: [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ # check channel order if self.config.ctrl_conditioning_channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.base_time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. 
t_emb = t_emb.to(dtype=sample.dtype) if self.config.ctrl_learn_time_embedding and apply_control: ctrl_temb = self.ctrl_time_embedding(t_emb, timestep_cond) base_temb = self.base_time_embedding(t_emb, timestep_cond) interpolation_param = self.config.time_embedding_mix**0.3 temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) else: temb = self.base_time_embedding(t_emb) # added time & text embeddings aug_emb = None if self.config.addition_embed_type is None: pass elif self.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.base_add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(temb.dtype) aug_emb = self.base_add_embedding(add_embeds) else: raise ValueError( f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.config.addition_embed_type} is currently not supported." 
) temb = temb + aug_emb if aug_emb is not None else temb # text embeddings cemb = encoder_hidden_states # Preparation h_ctrl = h_base = sample hs_base, hs_ctrl = [], [] # Cross Control guided_hint = self.controlnet_cond_embedding(controlnet_cond) # 1 - conv in & down h_base = self.base_conv_in(h_base) h_ctrl = self.ctrl_conv_in(h_ctrl) if guided_hint is not None: h_ctrl += guided_hint if apply_control: h_base = h_base + self.control_to_base_for_conv_in(h_ctrl) * conditioning_scale # add ctrl -> base hs_base.append(h_base) hs_ctrl.append(h_ctrl) for down in self.down_blocks: h_base, h_ctrl, residual_hb, residual_hc = down( hidden_states_base=h_base, hidden_states_ctrl=h_ctrl, temb=temb, encoder_hidden_states=cemb, conditioning_scale=conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, apply_control=apply_control, ) hs_base.extend(residual_hb) hs_ctrl.extend(residual_hc) # 2 - mid h_base, h_ctrl = self.mid_block( hidden_states_base=h_base, hidden_states_ctrl=h_ctrl, temb=temb, encoder_hidden_states=cemb, conditioning_scale=conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, apply_control=apply_control, ) # 3 - up for up in self.up_blocks: n_resnets = len(up.resnets) skips_hb = hs_base[-n_resnets:] skips_hc = hs_ctrl[-n_resnets:] hs_base = hs_base[:-n_resnets] hs_ctrl = hs_ctrl[:-n_resnets] h_base = up( hidden_states=h_base, res_hidden_states_tuple_base=skips_hb, res_hidden_states_tuple_ctrl=skips_hc, temb=temb, encoder_hidden_states=cemb, conditioning_scale=conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, apply_control=apply_control, ) # 4 - conv out h_base = self.base_conv_norm_out(h_base) h_base = self.base_conv_act(h_base) h_base = self.base_conv_out(h_base) if not return_dict: return (h_base,) return ControlNetXSOutput(sample=h_base) class ControlNetXSCrossAttnDownBlock2D(nn.Module): def __init__( self, base_in_channels: int, 
base_out_channels: int, ctrl_in_channels: int, ctrl_out_channels: int, temb_channels: int, norm_num_groups: int = 32, ctrl_max_norm_num_groups: int = 32, has_crossattn=True, transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, base_num_attention_heads: Optional[int] = 1, ctrl_num_attention_heads: Optional[int] = 1, cross_attention_dim: Optional[int] = 1024, add_downsample: bool = True, upcast_attention: Optional[bool] = False, use_linear_projection: Optional[bool] = True, ): super().__init__() base_resnets = [] base_attentions = [] ctrl_resnets = [] ctrl_attentions = [] ctrl_to_base = [] base_to_ctrl = [] num_layers = 2 # only support sd + sdxl if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): base_in_channels = base_in_channels if i == 0 else base_out_channels ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels # Before the resnet/attention application, information is concatted from base to control. 
# Concat doesn't require change in number of channels base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) base_resnets.append( ResnetBlock2D( in_channels=base_in_channels, out_channels=base_out_channels, temb_channels=temb_channels, groups=norm_num_groups, ) ) ctrl_resnets.append( ResnetBlock2D( in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl out_channels=ctrl_out_channels, temb_channels=temb_channels, groups=find_largest_factor( ctrl_in_channels + base_in_channels, max_factor=ctrl_max_norm_num_groups ), groups_out=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), eps=1e-5, ) ) if has_crossattn: base_attentions.append( Transformer2DModel( base_num_attention_heads, base_out_channels // base_num_attention_heads, in_channels=base_out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, norm_num_groups=norm_num_groups, ) ) ctrl_attentions.append( Transformer2DModel( ctrl_num_attention_heads, ctrl_out_channels // ctrl_num_attention_heads, in_channels=ctrl_out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), ) ) # After the resnet/attention application, information is added from control to base # Addition requires change in number of channels ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) if add_downsample: # Before the downsampler application, information is concatted from base to control # Concat doesn't require change in number of channels base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) self.base_downsamplers = Downsample2D( base_out_channels, use_conv=True, out_channels=base_out_channels, 
name="op" ) self.ctrl_downsamplers = Downsample2D( ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op" ) # After the downsampler application, information is added from control to base # Addition requires change in number of channels ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) else: self.base_downsamplers = None self.ctrl_downsamplers = None self.base_resnets = nn.ModuleList(base_resnets) self.ctrl_resnets = nn.ModuleList(ctrl_resnets) self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None] * num_layers self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None] * num_layers self.base_to_ctrl = nn.ModuleList(base_to_ctrl) self.ctrl_to_base = nn.ModuleList(ctrl_to_base) self.gradient_checkpointing = False @classmethod def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: DownBlockControlNetXSAdapter): # get params def get_first_cross_attention(block): return block.attentions[0].transformer_blocks[0].attn2 base_in_channels = base_downblock.resnets[0].in_channels base_out_channels = base_downblock.resnets[0].out_channels ctrl_in_channels = ( ctrl_downblock.resnets[0].in_channels - base_in_channels ) # base channels are concatted to ctrl channels in init ctrl_out_channels = ctrl_downblock.resnets[0].out_channels temb_channels = base_downblock.resnets[0].time_emb_proj.in_features num_groups = base_downblock.resnets[0].norm1.num_groups ctrl_num_groups = ctrl_downblock.resnets[0].norm1.num_groups if hasattr(base_downblock, "attentions"): has_crossattn = True transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks) base_num_attention_heads = get_first_cross_attention(base_downblock).heads ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim upcast_attention = 
def freeze_base_params(self) -> None:
    """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen
    for fine tuning."""
    # Make the whole block trainable first; the base parts are switched off afterwards.
    self.requires_grad_(True)

    # Collect the base sub-modules that actually exist on this block.
    frozen_modules = [self.base_resnets]
    # Without cross attention `base_attentions` is a plain list of `None`s, which has no parameters to freeze.
    if isinstance(self.base_attentions, nn.ModuleList):
        frozen_modules.append(self.base_attentions)
    if self.base_downsamplers is not None:
        frozen_modules.append(self.base_downsamplers)

    for module in frozen_modules:
        module.requires_grad_(False)
class ControlNetXSCrossAttnMidBlock2D(nn.Module):
    """Mid block holding a base and a control `UNetMidBlock2DCrossAttn` plus the two zero convolutions that connect
    them (base -> ctrl by concatenation before the blocks, ctrl -> base by addition after them)."""

    def __init__(
        self,
        base_channels: int,
        ctrl_channels: int,
        temb_channels: Optional[int] = None,
        norm_num_groups: int = 32,
        ctrl_max_norm_num_groups: int = 32,
        transformer_layers_per_block: int = 1,
        base_num_attention_heads: Optional[int] = 1,
        ctrl_num_attention_heads: Optional[int] = 1,
        cross_attention_dim: Optional[int] = 1024,
        upcast_attention: bool = False,
        use_linear_projection: Optional[bool] = True,
    ):
        super().__init__()

        # The control mid block receives its own features concatenated with the base features.
        ctrl_in_channels = ctrl_channels + base_channels
        # The number of norm groups must divide both the in and out channels of the control mid block.
        ctrl_norm_groups = find_largest_factor(gcd(ctrl_channels, ctrl_in_channels), ctrl_max_norm_num_groups)

        # Settings shared by the base and the control mid block.
        common = {
            "transformer_layers_per_block": transformer_layers_per_block,
            "temb_channels": temb_channels,
            "cross_attention_dim": cross_attention_dim,
            "use_linear_projection": use_linear_projection,
            "upcast_attention": upcast_attention,
        }

        # base -> ctrl: concatenation, so the channel count is kept
        self.base_to_ctrl = make_zero_conv(base_channels, base_channels)

        self.base_midblock = UNetMidBlock2DCrossAttn(
            in_channels=base_channels,
            resnet_groups=norm_num_groups,
            num_attention_heads=base_num_attention_heads,
            **common,
        )
        self.ctrl_midblock = UNetMidBlock2DCrossAttn(
            in_channels=ctrl_in_channels,
            out_channels=ctrl_channels,
            resnet_groups=ctrl_norm_groups,
            num_attention_heads=ctrl_num_attention_heads,
            **common,
        )

        # ctrl -> base: addition, so the channel count has to be mapped to the base's
        self.ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)

        self.gradient_checkpointing = False

    @classmethod
    def from_modules(
        cls,
        base_midblock: UNetMidBlock2DCrossAttn,
        ctrl_midblock: MidBlockControlNetXSAdapter,
    ):
        """Build the combined mid block from a base mid block and a control adapter mid block, copying their
        weights."""
        adapter = ctrl_midblock
        ctrl_inner = adapter.midblock

        # The first cross attention layer of each mid block is used to read the attention settings.
        base_attn2 = base_midblock.attentions[0].transformer_blocks[0].attn2
        ctrl_attn2 = ctrl_inner.attentions[0].transformer_blocks[0].attn2
        base_resnet = base_midblock.resnets[0]

        # create model
        model = cls(
            base_channels=adapter.ctrl_to_base.out_channels,
            ctrl_channels=adapter.ctrl_to_base.in_channels,
            temb_channels=base_resnet.time_emb_proj.in_features,
            norm_num_groups=base_resnet.norm1.num_groups,
            ctrl_max_norm_num_groups=ctrl_inner.resnets[0].norm1.num_groups,
            transformer_layers_per_block=len(base_midblock.attentions[0].transformer_blocks),
            base_num_attention_heads=base_attn2.heads,
            ctrl_num_attention_heads=ctrl_attn2.heads,
            cross_attention_dim=base_attn2.cross_attention_dim,
            upcast_attention=base_attn2.upcast_attention,
            use_linear_projection=base_midblock.attentions[0].use_linear_projection,
        )

        # load weights
        weight_pairs = (
            (model.base_to_ctrl, adapter.base_to_ctrl),
            (model.base_midblock, base_midblock),
            (model.ctrl_midblock, ctrl_inner),
            (model.ctrl_to_base, adapter.ctrl_to_base),
        )
        for target, source in weight_pairs:
            target.load_state_dict(source.state_dict())

        return model

    def freeze_base_params(self) -> None:
        """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else
        unfrozen for fine tuning."""
        # everything trainable first, then switch the base mid block off
        self.requires_grad_(True)
        self.base_midblock.requires_grad_(False)

    def forward(
        self,
        hidden_states_base: Tensor,
        temb: Tensor,
        encoder_hidden_states: Tensor,
        hidden_states_ctrl: Optional[Tensor] = None,
        conditioning_scale: Optional[float] = 1.0,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[Tensor] = None,
        encoder_attention_mask: Optional[Tensor] = None,
        apply_control: bool = True,
    ) -> Tuple[Tensor, Tensor]:
        """Run the base mid block and, if `apply_control` is set, the control mid block.

        Returns the base hidden states (with the scaled control contribution added when control is applied) and
        the control hidden states (returned unchanged when control is not applied).
        """
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # keyword arguments handed to both mid blocks
        shared = {
            "temb": temb,
            "encoder_hidden_states": encoder_hidden_states,
            "attention_mask": attention_mask,
            "cross_attention_kwargs": cross_attention_kwargs,
            "encoder_attention_mask": encoder_attention_mask,
        }

        if not apply_control:
            # base model only; the control features pass through untouched
            return self.base_midblock(hidden_states_base, **shared), hidden_states_ctrl

        # concat base -> ctrl (uses the base features from before the base mid block)
        ctrl_input = torch.cat([hidden_states_ctrl, self.base_to_ctrl(hidden_states_base)], dim=1)
        base_out = self.base_midblock(hidden_states_base, **shared)  # apply base mid block
        ctrl_out = self.ctrl_midblock(ctrl_input, **shared)  # apply ctrl mid block
        # add ctrl -> base
        return base_out + self.ctrl_to_base(ctrl_out) * conditioning_scale, ctrl_out
out_channels=out_channels, temb_channels=temb_channels, groups=norm_num_groups, ) ) if has_crossattn: attentions.append( Transformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, norm_num_groups=norm_num_groups, ) ) self.resnets = nn.ModuleList(resnets) self.attentions = nn.ModuleList(attentions) if has_crossattn else [None] * num_layers self.ctrl_to_base = nn.ModuleList(ctrl_to_base) if add_upsample: self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels) else: self.upsamplers = None self.gradient_checkpointing = False self.resolution_idx = resolution_idx @classmethod def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: UpBlockControlNetXSAdapter): ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base # get params def get_first_cross_attention(block): return block.attentions[0].transformer_blocks[0].attn2 out_channels = base_upblock.resnets[0].out_channels in_channels = base_upblock.resnets[-1].in_channels - out_channels prev_output_channels = base_upblock.resnets[0].in_channels - out_channels ctrl_skip_channelss = [c.in_channels for c in ctrl_to_base_skip_connections] temb_channels = base_upblock.resnets[0].time_emb_proj.in_features num_groups = base_upblock.resnets[0].norm1.num_groups resolution_idx = base_upblock.resolution_idx if hasattr(base_upblock, "attentions"): has_crossattn = True transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks) num_attention_heads = get_first_cross_attention(base_upblock).heads cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim upcast_attention = get_first_cross_attention(base_upblock).upcast_attention use_linear_projection = base_upblock.attentions[0].use_linear_projection else: has_crossattn = False 
def freeze_base_params(self) -> None:
    """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen
    for fine tuning."""
    # First make every parameter of the block trainable ...
    self.requires_grad_(True)

    # ... then turn gradients off for the base parts that are present.
    # `attentions` is a plain list of `None`s when the block has no cross attention; `upsamplers` may be `None`.
    optional_parts = (
        self.attentions if isinstance(self.attentions, nn.ModuleList) else None,
        self.upsamplers,
    )
    base_modules = [self.resnets] + [part for part in optional_parts if part is not None]
    for module in base_modules:
        module.requires_grad_(False)
upsample_size: Optional[int] = None, encoder_attention_mask: Optional[Tensor] = None, apply_control: bool = True, ) -> Tensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) and getattr(self, "b1", None) and getattr(self, "b2", None) ) def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward def maybe_apply_freeu_to_subblock(hidden_states, res_h_base): # FreeU: Only operate on the first two stages if is_freeu_enabled: return apply_freeu( self.resolution_idx, hidden_states, res_h_base, s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2, ) else: return hidden_states, res_h_base for resnet, attn, c2b, res_h_base, res_h_ctrl in zip( self.resnets, self.attentions, self.ctrl_to_base, reversed(res_hidden_states_tuple_base), reversed(res_hidden_states_tuple_ctrl), ): if apply_control: hidden_states += c2b(res_h_ctrl) * conditioning_scale hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base) hidden_states = torch.cat([hidden_states, res_h_base], dim=1) if self.training and self.gradient_checkpointing: ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) else: hidden_states = resnet(hidden_states, temb) if attn is not None: hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] if self.upsamplers is not None: hidden_states = 
def make_zero_conv(in_channels, out_channels=None):
    """Build a 1x1 `nn.Conv2d` whose weight and bias are initialised to zero.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels. Defaults to `in_channels`; the `None` default used to be
            forwarded to `nn.Conv2d` unchanged, which cannot construct a layer from it.

    Returns:
        The zero-initialised convolution module.
    """
    if out_channels is None:
        out_channels = in_channels
    return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))


def zero_module(module):
    """Set every parameter of `module` to zero in place and return the same module object."""
    for param in module.parameters():
        nn.init.zeros_(param)
    return module


def find_largest_factor(number, max_factor):
    """Return the largest divisor of `number` that is not greater than `max_factor`.

    Args:
        number: integer to find a divisor of.
        max_factor: upper bound for the divisor; must be at least 1.

    Returns:
        `number` itself when `max_factor >= number`, otherwise the largest value in `[1, max_factor]` that
        divides `number` evenly.

    Raises:
        ValueError: if `max_factor` is smaller than 1. Such input previously fell through the search loop
            (returning `None` for 0, or a negative "factor" for negative bounds).
    """
    if max_factor < 1:
        raise ValueError(f"`max_factor` must be a positive integer, got {max_factor}.")
    if max_factor >= number:
        return number

    factor = max_factor
    # 1 divides every integer, so the search always terminates for integer input.
    while number % factor != 0:
        factor -= 1
    return factor
""" def __init__( self, channels: int, use_conv: bool = False, out_channels: Optional[int] = None, padding: int = 1, name: str = "conv", ): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.padding = padding stride = 2 self.name = name if use_conv: self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) def forward(self, inputs: torch.Tensor) -> torch.Tensor: assert inputs.shape[1] == self.channels return self.conv(inputs) class Downsample2D(nn.Module): """A 2D downsampling layer with an optional convolution. Parameters: channels (`int`): number of channels in the inputs and outputs. use_conv (`bool`, default `False`): option to use a convolution. out_channels (`int`, optional): number of output channels. Defaults to `channels`. padding (`int`, default `1`): padding for the convolution. name (`str`, default `conv`): name of the downsampling 2D layer. 
""" def __init__( self, channels: int, use_conv: bool = False, out_channels: Optional[int] = None, padding: int = 1, name: str = "conv", kernel_size=3, norm_type=None, eps=None, elementwise_affine=None, bias=True, ): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.padding = padding stride = 2 self.name = name if norm_type == "ln_norm": self.norm = nn.LayerNorm(channels, eps, elementwise_affine) elif norm_type == "rms_norm": self.norm = RMSNorm(channels, eps, elementwise_affine) elif norm_type is None: self.norm = None else: raise ValueError(f"unknown norm_type: {norm_type}") if use_conv: conv = nn.Conv2d( self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) else: assert self.channels == self.out_channels conv = nn.AvgPool2d(kernel_size=stride, stride=stride) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": self.Conv2d_0 = conv self.conv = conv elif name == "Conv2d_0": self.conv = conv else: self.conv = conv def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(
        self,
        channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
    ):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            # Only the weight and bias of this layer are used (see `forward`); the strided convolution itself
            # is carried out inside `_downsample_2d`.
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(
        self,
        hidden_states: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        kernel: Optional[torch.Tensor] = None,
        factor: int = 2,
        gain: float = 1,
    ) -> torch.Tensor:
        """FIR-filter `hidden_states` and downsample it, optionally fused with a strided convolution.

        When `self.use_conv` is set, the input is padded and filtered once with `upfirdn2d_native` and then
        convolved with `weight` using stride `factor`. Otherwise `upfirdn2d_native` performs the filtering and
        the downsampling by `factor` itself.

        Args:
            hidden_states (`torch.Tensor`):
                Input tensor passed to `upfirdn2d_native`.
            weight (`torch.Tensor`, *optional*):
                Convolution weight; required when `self.use_conv` is set. Its last two dimensions are read as
                the spatial kernel size and it is handed to `F.conv2d` unchanged.
            kernel (`torch.Tensor`, *optional*):
                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
                which corresponds to average pooling.
            factor (`int`, *optional*, default to `2`):
                Integer downsampling factor.
            gain (`float`, *optional*, default to `1.0`):
                Scaling factor for signal magnitude.

        Returns:
            output (`torch.Tensor`):
                The filtered and downsampled tensor.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # Set up the FIR kernel. `as_tensor` avoids the copy-construct warning when a tensor is passed in; the
        # normalisation below is out-of-place so a caller-owned tensor is never modified.
        kernel = torch.as_tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel = kernel / torch.sum(kernel)
        kernel = kernel * gain
        # Move (rather than re-create) the kernel onto the input's device, like `downsample_2d` does.
        kernel = kernel.to(device=hidden_states.device)

        if self.use_conv:
            conv_width = weight.shape[-1]
            # Pad for both the FIR filter and the convolution in one go.
            pad_value = (kernel.shape[0] - factor) + (conv_width - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                kernel,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel,
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Downsample `hidden_states` with the configured FIR kernel (and convolution, if enabled)."""
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            # The bias is added separately because only the weight takes part in the fused convolution.
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
class CogVideoXDownsample3D(nn.Module):
    # TODO: Wait for paper release.
    r"""
    A 3D downsampling layer used in [CogVideoX]() by Tsinghua University & ZhipuAI

    Args:
        in_channels (`int`):
            Number of channels in the input image.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `2`):
            Stride of the convolution.
        padding (`int`, defaults to `0`):
            Padding argument handed to the 2D convolution.
        compress_time (`bool`, defaults to `False`):
            Whether or not to compress the time dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 2,
        padding: int = 0,
        compress_time: bool = False,
    ):
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.compress_time = compress_time

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Optionally average-pool the frame axis, then apply the 2D convolution to every frame.

        `x` is unpacked as `(batch_size, channels, frames, height, width)`; the result uses the same layout.
        """
        if self.compress_time:
            batch, chans, num_frames, rows, cols = x.shape

            # (batch, chans, frames, rows, cols) -> (batch * rows * cols, chans, frames): frames become the
            # pooled axis.
            seq = x.permute(0, 3, 4, 1, 2).reshape(batch * rows * cols, chans, num_frames)

            if num_frames % 2 == 1:
                # Odd frame count: keep the first frame as is and pool the remaining ones pairwise.
                head, tail = seq[..., :1], seq[..., 1:]
                if tail.shape[-1] > 0:
                    tail = F.avg_pool1d(tail, kernel_size=2, stride=2)
                seq = torch.cat([head, tail], dim=-1)
            else:
                seq = F.avg_pool1d(seq, kernel_size=2, stride=2)

            # (batch * rows * cols, chans, frames') -> (batch, chans, frames', rows, cols)
            x = seq.reshape(batch, rows, cols, chans, seq.shape[-1]).permute(0, 3, 4, 1, 2)

        # Zero-pad one element at the end of each of the last two dimensions.
        x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)

        batch, chans, num_frames, rows, cols = x.shape
        # Fold the frames into the batch so the 2D convolution sees one image per frame.
        images = x.permute(0, 2, 1, 3, 4).reshape(batch * num_frames, chans, rows, cols)
        images = self.conv(images)
        # (batch * frames, chans', rows', cols') -> (batch, chans', frames, rows', cols')
        return images.reshape(batch, num_frames, images.shape[1], images.shape[2], images.shape[3]).permute(
            0, 2, 1, 3, 4
        )
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Create sinusoidal timestep embeddings, following Denoising Diffusion Probabilistic Models.

    Args
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Subtracted from `embedding_dim // 2` in the denominator of the frequency exponent.
        scale (float):
            Scaling factor applied to the sinusoid arguments.
        max_period (int):
            Base of the geometric frequency progression.
    Returns
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"

    num_freqs = embedding_dim // 2
    positions = torch.arange(start=0, end=num_freqs, dtype=torch.float32, device=timesteps.device)
    log_freqs = (-math.log(max_period) * positions) / (num_freqs - downscale_freq_shift)
    freqs = torch.exp(log_freqs)

    # Outer product of timesteps and frequencies, then the global scale.
    angles = scale * (timesteps[:, None].float() * freqs[None, :])

    sin_part, cos_part = torch.sin(angles), torch.cos(angles)
    halves = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]
    emb = torch.cat(halves, dim=-1)

    # An odd `embedding_dim` leaves one column unfilled: append a zero column.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    Build a 2D sine/cosine positional embedding table for a token grid.

    grid_size: int (used for both axes) or a pair indexed as `grid_size[0]`, `grid_size[1]`
    base_size, interpolation_scale: the coordinates along an axis of length `n` are
        `arange(n) / (n / base_size) / interpolation_scale`
    return: the array produced by `get_2d_sincos_pos_embed_from_grid`; when `cls_token` is set and
        `extra_tokens > 0`, `extra_tokens` all-zero rows are prepended to it
    """
    if isinstance(grid_size, int):
        len_h = len_w = grid_size
    else:
        len_h, len_w = grid_size[0], grid_size[1]

    def _axis_coords(length):
        # Same operation order as the two explicit expressions this replaces.
        return np.arange(length, dtype=np.float32) / (length / base_size) / interpolation_scale

    coords_h = _axis_coords(len_h)
    coords_w = _axis_coords(len_w)

    # meshgrid receives the width axis first
    mesh = np.stack(np.meshgrid(coords_w, coords_h), axis=0)
    mesh = mesh.reshape([2, 1, len_w, len_h])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not (cls_token and extra_tokens > 0):
        return table
    return np.concatenate([np.zeros([extra_tokens, embed_dim]), table], axis=0)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode a set of positions with sine/cosine features.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded; it is flattened to size (M,)
    out: (M, D) array, the sine half followed by the cosine half
    """
    if embed_dim % 2:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    exponents = np.arange(half_dim, dtype=np.float64) / (embed_dim / 2.0)
    inv_freq = 1.0 / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    # Outer product via broadcasting: every position times every frequency.
    angles = flat_pos[:, None] * inv_freq[None, :]  # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
def cropped_pos_embed(self, height, width):
    """Return the centre crop of the stored positional embedding grid for an input of `height` x `width`.

    Both sizes are first divided by `self.patch_size`. The stored table is viewed as a
    `pos_embed_max_size` x `pos_embed_max_size` grid, the centred window is cut out and flattened back to
    shape `(1, rows * cols, dim)`.

    Raises:
        ValueError: if `pos_embed_max_size` is unset, or if either patch count exceeds it.
    """
    max_size = self.pos_embed_max_size
    if max_size is None:
        raise ValueError("`pos_embed_max_size` must be set for cropping.")

    rows = height // self.patch_size
    cols = width // self.patch_size
    for label, count in (("Height", rows), ("Width", cols)):
        if count > max_size:
            raise ValueError(f"{label} ({count}) cannot be greater than `pos_embed_max_size`: {max_size}.")

    # Offsets that centre the window inside the full grid.
    row_start = (max_size - rows) // 2
    col_start = (max_size - cols) // 2

    grid = self.pos_embed.reshape(1, max_size, max_size, -1)
    window = grid[:, row_start : row_start + rows, col_start : col_start + cols, :]
    return window.reshape(1, -1, window.shape[-1])
class LuminaPatchEmbed(nn.Module):
    """2D Image to Patch Embedding with support for Lumina-T2X"""

    def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
        super().__init__()
        self.patch_size = patch_size
        # Each patch is flattened to `patch_size * patch_size * in_channels` values before projection.
        patch_features = patch_size * patch_size * in_channels
        self.proj = nn.Linear(in_features=patch_features, out_features=embed_dim, bias=bias)

    def forward(self, x, freqs_cis):
        """
        Patchifies and embeds the input tensor.

        Args:
            x (torch.Tensor): The input to be patchified and embedded; unpacked as
                `(batch_size, channel, height, width)`.
            freqs_cis (torch.Tensor): Frequency tensor; it is moved to the input's device and its first two
                dimensions are cropped to the token grid.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
            and embedded tensor, an all-ones int32 mask with one entry per patch, the original image size repeated
            once per batch element, and the cropped, flattened frequency tensor with a leading batch dimension.
        """
        freqs_cis = freqs_cis.to(x[0].device)

        patch = self.patch_size
        batch_size, channel, height, width = x.size()
        rows, cols = height // patch, width // patch

        # (B, C, H, W) -> (B, rows, cols, C, patch, patch) -> (B, rows, cols, C * patch * patch)
        patches = x.view(batch_size, channel, rows, patch, cols, patch)
        patches = patches.permute(0, 2, 4, 1, 3, 5).flatten(3)

        # Project every patch, then merge the two grid axes into one token axis.
        tokens = self.proj(patches).flatten(1, 2)

        mask = torch.ones(tokens.shape[0], tokens.shape[1], dtype=torch.int32, device=tokens.device)
        image_sizes = [(height, width)] * batch_size
        cropped_freqs = freqs_cis[:rows, :cols].flatten(0, 1).unsqueeze(0)

        return tokens, mask, image_sizes, cropped_freqs
""" text_embeds = self.text_proj(text_embeds) batch, num_frames, channels, height, width = image_embeds.shape image_embeds = image_embeds.reshape(-1, channels, height, width) image_embeds = self.proj(image_embeds) image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:]) image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] embeds = torch.cat( [text_embeds, image_embeds], dim=1 ).contiguous() # [batch, seq_length + num_frames x height x width, channels] return embeds def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): """ RoPE for image tokens with 2d structure. Args: embed_dim: (`int`): The embedding dimension size crops_coords (`Tuple[int]`) The top-left and bottom-right coordinates of the crop. grid_size (`Tuple[int]`): The grid size of the positional embedding. use_real (`bool`): If True, return real part and imaginary part separately. Otherwise, return complex numbers. Returns: `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`. 
""" start, stop = crops_coords grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) # [2, W, H] grid = grid.reshape([2, 1, *grid.shape[1:]]) pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) return pos_embed def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): assert embed_dim % 4 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_rotary_pos_embed( embed_dim // 2, grid[0].reshape(-1), use_real=use_real ) # (H*W, D/2) if use_real else (H*W, D/4) emb_w = get_1d_rotary_pos_embed( embed_dim // 2, grid[1].reshape(-1), use_real=use_real ) # (H*W, D/2) if use_real else (H*W, D/4) if use_real: cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D) sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D) return cos, sin else: emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) return emb def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0): assert embed_dim % 4 == 0 emb_h = get_1d_rotary_pos_embed( embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor ) # (H, D/4) emb_w = get_1d_rotary_pos_embed( embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor ) # (W, D/4) emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1) emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1) emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2) return emb def get_1d_rotary_pos_embed( dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False, linear_factor=1.0, ntk_factor=1.0, repeat_interleave_real=True, ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. 
This function calculates a frequency tensor from the given dimension 'dim' and the position indices 'pos'. The 'theta' parameter scales the frequencies. When `use_real` is False the returned tensor contains complex values in complex64 data type; otherwise a (cos, sin) pair of real tensors is returned. Args: dim (`int`): Dimension of the frequency tensor. pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar theta (`float`, *optional*, defaults to 10000.0): Scaling factor for frequency computation. Defaults to 10000.0. use_real (`bool`, *optional*): If True, return real part and imaginary part separately. Otherwise, return complex numbers. linear_factor (`float`, *optional*, defaults to 1.0): Scaling factor for the context extrapolation. Defaults to 1.0. ntk_factor (`float`, *optional*, defaults to 1.0): Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. repeat_interleave_real (`bool`, *optional*, defaults to `True`): If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. Otherwise, they are concatenated with themselves. Returns: `torch.Tensor`: Precomputed frequency tensor with complex exponentials.
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary position embeddings to a single query or key tensor.

    The function rotates `x` by the precomputed frequencies in `freqs_cis` and returns one tensor with the same
    dtype as `x`. Call it once for the query and once for the key.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to rotate. It must be 4-D (the code flattens from axis 3). With `use_real=True` the
            cos/sin tables get two leading singleton axes, so they broadcast against a `[B, H, S, D]` layout.
        freqs_cis (`Tuple[torch.Tensor]` or `torch.Tensor`):
            With `use_real=True`: a `(cos, sin)` pair, each `[S, D]`. With `use_real=False`: a complex tensor that
            receives a new axis at position 2 before being multiplied with the complex view of `x`
            (presumably `[B, S, D/2]` against an `x` laid out `[B, S, H, D]` - confirm against the caller).
        use_real (`bool`, *optional*, defaults to `True`):
            Whether `freqs_cis` is a real `(cos, sin)` pair (`True`) or a complex tensor (`False`).
        use_real_unbind_dim (`int`, *optional*, defaults to -1):
            Only used when `use_real=True`. `-1` treats the last axis as interleaved (real, imag) pairs; `-2`
            treats the first half of the last axis as the real parts and the second half as the imaginary parts.

    Returns:
        `torch.Tensor`: The rotated tensor, cast back to the dtype of `x`.

    Raises:
        ValueError: If `use_real=True` and `use_real_unbind_dim` is neither -1 nor -2.
    """
    if use_real:
        cos, sin = freqs_cis  # each [S, D]
        # Two leading singleton axes let the tables broadcast over the first two axes of `x`.
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for example in Lumina: consecutive elements form (real, imag) pairs.
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # each [..., D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for example in Stable Audio: first half is real, second half is imaginary.
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # each [..., D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        # Compute in float32 and cast back so low-precision inputs keep their dtype.
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # View consecutive pairs of the last axis as complex numbers and rotate by complex multiplication.
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
class SinusoidalPositionalEmbedding(nn.Module):
    """Add fixed sinusoidal position information to a sequence of embeddings.

    A table of shape `(1, max_seq_length, embed_dim)` is built once at construction time and stored as the buffer
    `pe`; even channels hold sines and odd channels hold cosines. `forward` adds the first `seq_length` rows of
    that table to an input of shape `(batch_size, seq_length, embed_dim)`.

    Args:
        embed_dim: (int): Dimension of the positional embedding.
        max_seq_length: Maximum sequence length to apply positional embeddings
    """

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        positions = torch.arange(max_seq_length).unsqueeze(1)
        # One decay rate per (sin, cos) channel pair.
        inv_freq = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        table = torch.zeros(1, max_seq_length, embed_dim)
        table[0, :, 0::2] = torch.sin(positions * inv_freq)
        table[0, :, 1::2] = torch.cos(positions * inv_freq)
        self.register_buffer("pe", table)

    def forward(self, x):
        # Unpacking enforces a 3-D input; only the sequence length is needed.
        _batch, seq_length, _width = x.shape
        return x + self.pe[:, :seq_length]
class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

    When `dropout_prob > 0` the embedding table gets one extra row (index `num_classes`) that represents the
    "dropped" / unconditional label.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # Reserve an extra row for the null label only when dropout can actually happen.
        use_cfg_embedding = int(dropout_prob > 0)
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.

        Args:
            labels (`torch.LongTensor`): Class indices, one per batch element.
            force_drop_ids (array-like, *optional*):
                Per-sample flags; entries equal to 1 are dropped. Accepts a tensor, NumPy array or Python sequence.
                When `None`, each label is dropped independently with probability `dropout_prob`.

        Returns:
            `torch.LongTensor`: `labels` with dropped entries replaced by `num_classes`.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            # Convert first, compare second: comparing a Python list with `1` yields a plain `False`, which
            # would silently drop nothing. Building the mask on `labels.device` also avoids device mismatches.
            drop_ids = torch.as_tensor(force_drop_ids, device=labels.device) == 1
        labels = torch.where(drop_ids, torch.full_like(labels, self.num_classes), labels)
        return labels

    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
        """
        Look up the embeddings for `labels`, optionally dropping some of them first.

        Labels are dropped when the module is in training mode with `dropout_prob > 0`, or whenever
        `force_drop_ids` is given.
        """
        use_dropout = self.dropout_prob > 0
        # NOTE: forcing a drop on a module built with `dropout_prob == 0` indexes row `num_classes`, which the
        # table does not contain in that configuration.
        if (self.training and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings
class CombinedTimestepLabelEmbeddings(nn.Module):
    """Sum of a timestep embedding and a class-label embedding.

    The timestep is first projected to 256 sinusoidal channels, then mapped to `embedding_dim`; class labels are
    embedded to the same width (with label dropout probability `class_dropout_prob`) and the two are added.
    """

    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        freq_channels = 256
        self.time_proj = Timesteps(num_channels=freq_channels, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=freq_channels, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        # Cast the sinusoidal projection to the requested dtype before the learned embedder.
        time_features = self.time_proj(timestep).to(dtype=hidden_dtype)
        time_emb = self.timestep_embedder(time_features)  # (N, D)

        label_emb = self.class_embedder(class_labels)  # (N, D)

        return time_emb + label_emb  # (N, D)
class HunyuanDiTAttentionPool(nn.Module):
    # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
    """Pool a token sequence into one vector with a single attention query.

    The mean of the tokens is prepended to the sequence, a learned positional embedding of shape
    `(spacial_dim + 1, embed_dim)` is added, and the prepended token attends over the whole sequence. The output
    projection maps to `output_dim` (or `embed_dim` when `output_dim` is falsy).
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        init_scale = embed_dim**0.5
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / init_scale)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        tokens = x.permute(1, 0, 2)  # NLC -> LNC
        pooled_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([pooled_token, tokens], dim=0)  # (L+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)  # (L+1)NC

        # The functional API expects the q/k/v biases packed into one vector, in that order.
        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        attended, _ = F.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        # Only one query token was used, so drop the length-1 leading axis.
        return attended.squeeze(0)
class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
    """Sum of a timestep embedding and a mask-weighted pooled caption embedding.

    The timestep goes through a sinusoidal projection of `frequency_embedding_size` channels followed by a
    learned embedder of width `hidden_size`. Caption features are averaged over axis 1 using `caption_mask` as
    weights, layer-normalised and linearly projected from `cross_attention_dim` to `hidden_size`.
    """

    def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
        super().__init__()
        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size,
            flip_sin_to_cos=True,
            downscale_freq_shift=0.0,
        )

        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size,
            time_embed_dim=hidden_size,
        )

        caption_norm = nn.LayerNorm(cross_attention_dim)
        caption_proj = nn.Linear(cross_attention_dim, hidden_size, bias=True)
        self.caption_embedder = nn.Sequential(caption_norm, caption_proj)

    def forward(self, timestep, caption_feat, caption_mask):
        # timestep embedding: cast the sinusoidal features to the embedder's own weight dtype
        time_freq = self.time_proj(timestep)
        embedder_dtype = self.timestep_embedder.linear_1.weight.dtype
        time_embed = self.timestep_embedder(time_freq.to(dtype=embedder_dtype))

        # caption condition embedding: masked mean over axis 1
        mask = caption_mask.float().unsqueeze(-1)
        masked_sum = (caption_feat * mask).sum(dim=1)
        mask_total = mask.sum(dim=1)
        # `.to(caption_feat)` restores the dtype/device of the incoming features after the float32 maths.
        pooled_caption = (masked_sum / mask_total).to(caption_feat)
        caption_embed = self.caption_embedder(pooled_caption)

        return time_embed + caption_embed
class ImageHintTimeEmbedding(nn.Module):
    """Project image embeddings to the time-embedding width and encode a 3-channel hint image.

    `forward` returns a pair: the layer-normalised linear projection of `image_embeds`, and the hint passed
    through a stack of eight 3x3 convolutions (three of them with stride 2) that ends in 4 output channels.
    """

    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

        # (in_channels, out_channels, stride) for every convolution that is followed by a SiLU.
        conv_plan = (
            (3, 16, 1),
            (16, 16, 1),
            (16, 32, 2),
            (32, 32, 1),
            (32, 96, 2),
            (96, 96, 1),
            (96, 256, 2),
        )
        hint_layers = []
        for c_in, c_out, stride in conv_plan:
            hint_layers.append(nn.Conv2d(c_in, c_out, 3, padding=1, stride=stride))
            hint_layers.append(nn.SiLU())
        # Final convolution has no activation after it.
        hint_layers.append(nn.Conv2d(256, 4, 3, padding=1))
        self.input_hint_block = nn.Sequential(*hint_layers)

    def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
        # image
        projected = self.image_norm(self.image_proj(image_embeds))
        # hint
        encoded_hint = self.input_hint_block(hint)
        return projected, encoded_hint
def get_fourier_embeds_from_boundingbox(embed_dim, box):
    """
    Compute Fourier (sin/cos) features for bounding-box coordinates.

    Args:
        embed_dim: int, number of frequencies; frequency `i` is `100 ** (i / embed_dim)`
        box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
    Returns:
        [B x N x (embed_dim * 2 * 4)] tensor of positional embeddings. Along the last axis the layout is
        frequency-major, then sin before cos, then the four box coordinates.
    """
    batch_size, num_boxes = box.shape[:2]

    frequencies = 100 ** (torch.arange(embed_dim) / embed_dim)
    frequencies = frequencies[None, None, None].to(device=box.device, dtype=box.dtype)

    # [B, N, 4, 1] * [1, 1, 1, E] -> [B, N, 4, E]
    angles = frequencies * box.unsqueeze(-1)

    # [B, N, 4, E, 2] with sin at index 0 and cos at index 1 of the last axis
    features = torch.stack((angles.sin(), angles.cos()), dim=-1)
    # -> [B, N, E, 2, 4] so that flattening yields the documented ordering
    features = features.permute(0, 1, 3, 4, 2)

    return features.reshape(batch_size, num_boxes, embed_dim * 2 * 4)
nn.SiLU(), nn.Linear(512, out_dim), ) self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) elif feature_type == "text-image": self.linears_text = nn.Sequential( nn.Linear(self.positive_len + self.position_dim, 512), nn.SiLU(), nn.Linear(512, 512), nn.SiLU(), nn.Linear(512, out_dim), ) self.linears_image = nn.Sequential( nn.Linear(self.positive_len + self.position_dim, 512), nn.SiLU(), nn.Linear(512, 512), nn.SiLU(), nn.Linear(512, out_dim), ) self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) def forward( self, boxes, masks, positive_embeddings=None, phrases_masks=None, image_masks=None, phrases_embeddings=None, image_embeddings=None, ): masks = masks.unsqueeze(-1) # embedding position (it may includes padding as placeholder) xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C # learnable null embedding xyxy_null = self.null_position_feature.view(1, 1, -1) # replace padding with learnable null embedding xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null # positionet with text only information if positive_embeddings is not None: # learnable null embedding positive_null = self.null_positive_feature.view(1, 1, -1) # replace padding with learnable null embedding positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) # positionet with text and image information else: phrases_masks = phrases_masks.unsqueeze(-1) image_masks = image_masks.unsqueeze(-1) # learnable null embedding text_null = self.null_text_feature.view(1, 1, -1) image_null = self.null_image_feature.view(1, 1, -1) # replace padding with learnable null embedding phrases_embeddings = phrases_embeddings * phrases_masks + (1 - 
phrases_masks) * text_null image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)) objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1)) objs = torch.cat([objs_text, objs_image], dim=1) return objs class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): """ For PixArt-Alpha. Reference: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 """ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): super().__init__() self.outdim = size_emb_dim self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.use_additional_conditions = use_additional_conditions if use_additional_conditions: self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) if self.use_additional_conditions: resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype) aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1) conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1) else: conditioning = timesteps_emb return 
class PixArtAlphaTextProjection(nn.Module):
    """
    Two-layer MLP that projects caption embeddings: linear -> activation -> linear.

    `act_fn` selects the activation and must be one of `"gelu_tanh"`, `"silu"` or `"silu_fp32"`; any other value
    raises `ValueError`. When `out_features` is `None` the output width equals `hidden_size`.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
        super().__init__()
        out_width = hidden_size if out_features is None else out_features

        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)

        if act_fn == "silu":
            activation = nn.SiLU()
        elif act_fn == "silu_fp32":
            activation = FP32SiLU()
        elif act_fn == "gelu_tanh":
            activation = nn.GELU(approximate="tanh")
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.act_1 = activation

        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_width, bias=True)

    def forward(self, caption):
        return self.linear_2(self.act_1(self.linear_1(caption)))
class IPAdapterFaceIDPlusImageProjection(nn.Module):
    """FacePerceiverResampler of IP-Adapter Plus.

    ID embeddings are expanded into `num_tokens` latent tokens which then attend to CLIP embeddings through a
    stack of `IPAdapterPlusImageProjectionBlock`s. The CLIP embeddings are not a `forward` argument: they are read
    from the attribute `clip_embeds`, which is `None` after construction and must be assigned before `forward`
    is called.

    Args:
        embed_dims (int): The feature dimension. Defaults to 768.
        output_dims (int): The number of output channels, that is the same
            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 768.
        hidden_dims (int):
            The number of hidden channels, used as the input width of `proj_in` for the CLIP embeddings.
            Defaults to 1280.
        id_embeddings_dim (int): Width of the incoming ID embeddings. Defaults to 512.
        depth (int): The number of blocks. Defaults to 4.
        dim_head (int): The number of head channels. Defaults to 64.
        heads (int): Parallel attention heads. Defaults to 16.
        num_tokens (int): Number of tokens the ID embedding is expanded into. Defaults to 4.
        num_queries (int): The number of queries. Defaults to 8. Accepted but not used by this class.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        ffproj_ratio (float): The expansion ratio of feedforward network hidden
            layer channels (for ID embeddings). Defaults to 2.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        output_dims: int = 768,
        hidden_dims: int = 1280,
        id_embeddings_dim: int = 512,
        depth: int = 4,
        dim_head: int = 64,
        heads: int = 16,
        num_tokens: int = 4,
        num_queries: int = 8,
        ffn_ratio: float = 4,
        ffproj_ratio: int = 2,
    ) -> None:
        super().__init__()
        from .attention import FeedForward

        self.num_tokens = num_tokens
        self.embed_dim = embed_dims
        # Runtime state, assigned from outside the module before `forward` runs.
        self.clip_embeds = None
        self.shortcut = False
        self.shortcut_scale = 1.0

        id_out_width = embed_dims * num_tokens
        self.proj = FeedForward(id_embeddings_dim, id_out_width, activation_fn="gelu", mult=ffproj_ratio)
        self.norm = nn.LayerNorm(embed_dims)

        self.proj_in = nn.Linear(hidden_dims, embed_dims)

        self.proj_out = nn.Linear(embed_dims, output_dims)
        self.norm_out = nn.LayerNorm(output_dims)

        blocks = [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            id_embeds (torch.Tensor): Input Tensor (ID embeds).
        Returns:
            torch.Tensor: Output Tensor.
        """
        # Match the dtype of the externally supplied CLIP embeddings.
        id_embeds = id_embeds.to(self.clip_embeds.dtype)
        id_tokens = self.proj(id_embeds).reshape(-1, self.num_tokens, self.embed_dim)
        id_tokens = self.norm(id_tokens)

        # The projected CLIP embeddings are indexed up to axis 3, so they are 4-D; merge the two leading axes.
        clip_embeds = self.proj_in(self.clip_embeds)
        context = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])

        latents = id_tokens
        for block in self.layers:
            latents = block(context, latents, latents)

        out = self.norm_out(self.proj_out(latents))
        if self.shortcut:
            out = id_tokens + self.shortcut_scale * out
        return out
def get_sinusoidal_embeddings(
    timesteps: jnp.ndarray,
    embedding_dim: int,
    freq_shift: float = 1,
    min_timescale: float = 1,
    max_timescale: float = 1.0e4,
    flip_sin_to_cos: bool = False,
    scale: float = 1.0,
) -> jnp.ndarray:
    """Returns the positional encoding (same as Tensor2Tensor).

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim: The number of output channels. Must be even.
        freq_shift: Subtracted from `embedding_dim // 2` in the denominator of the log-timescale increment.
        min_timescale: Multiplies the inverse timescales and divides `max_timescale` inside the logarithm.
        max_timescale: The largest time unit.
        flip_sin_to_cos: When true the cosine half comes first, otherwise the sine half does.
        scale: Multiplier applied to the embeddings before the sine and cosine are taken.

    Returns:
        a Tensor of timing signals [N, num_channels]
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
    assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"

    half_dim = float(embedding_dim // 2)
    # NOTE: the denominator is zero when `embedding_dim // 2 == freq_shift`.
    log_step = math.log(max_timescale / min_timescale) / (half_dim - freq_shift)
    inv_timescales = min_timescale * jnp.exp(jnp.arange(half_dim, dtype=jnp.float32) * -log_step)

    # outer product of timesteps and inverse timescales, then scaled
    scaled_time = scale * (jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0))

    sin_part = jnp.sin(scaled_time)
    cos_part = jnp.cos(scaled_time)
    halves = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]

    return jnp.reshape(jnp.concatenate(halves, axis=1), [jnp.shape(timesteps)[0], embedding_dim])
class FlaxTimesteps(nn.Module):
    r"""
    Flax wrapper that forwards its fields to `get_sinusoidal_embeddings`, producing the sinusoidal time step
    embeddings described in https://arxiv.org/abs/2006.11239

    Args:
        dim (`int`, *optional*, defaults to `32`):
                Time step embedding dimension
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
                Passed through as `flip_sin_to_cos`
        freq_shift (`float`, *optional*, defaults to `1`):
                Passed through as `freq_shift`
    """

    dim: int = 32
    flip_sin_to_cos: bool = False
    freq_shift: float = 1

    @nn.compact
    def __call__(self, timesteps):
        embedding_kwargs = {
            "embedding_dim": self.dim,
            "flip_sin_to_cos": self.flip_sin_to_cos,
            "freq_shift": self.freq_shift,
        }
        return get_sinusoidal_embeddings(timesteps, **embedding_kwargs)
def text_encoder_attn_modules(text_encoder):
    """Collect `(qualified_name, module)` pairs for the self-attention module of every encoder layer.

    Only `CLIPTextModel` and `CLIPTextModelWithProjection` instances are supported; any other class raises a
    `ValueError`.
    """
    if not isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
        raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")

    return [
        (f"text_model.encoder.layers.{idx}.self_attn", encoder_layer.self_attn)
        for idx, encoder_layer in enumerate(text_encoder.text_model.encoder.layers)
    ]
class PatchedLoraProjection(torch.nn.Module):
    """Wraps a linear layer together with a LoRA linear layer whose scaled output is added to it.

    The LoRA weights can be folded into the wrapped layer with `_fuse_lora` and removed again with
    `_unfuse_lora`.
    """

    def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
        deprecation_message = "Use of `PatchedLoraProjection` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
        deprecate("PatchedLoraProjection", "1.0.0", deprecation_message)

        super().__init__()
        from ..models.lora import LoRALinearLayer

        self.regular_linear_layer = regular_linear_layer

        base_weight = self.regular_linear_layer.weight
        # the LoRA layer lives on the wrapped layer's device; its dtype follows the wrapped layer
        # unless one was passed explicitly
        self.lora_linear_layer = LoRALinearLayer(
            self.regular_linear_layer.in_features,
            self.regular_linear_layer.out_features,
            network_alpha=network_alpha,
            device=base_weight.device,
            dtype=base_weight.dtype if dtype is None else dtype,
            rank=rank,
        )

        self.lora_scale = lora_scale

    # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
    # when saving the whole text encoder model and when LoRA is unloaded or fused
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if self.lora_linear_layer is not None:
            return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)

        return self.regular_linear_layer.state_dict(
            *args, destination=destination, prefix=prefix, keep_vars=keep_vars
        )

    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
        """Add `lora_scale * (up @ down)` to the wrapped weight and drop the LoRA layer."""
        lora = self.lora_linear_layer
        if lora is None:
            return

        base = self.regular_linear_layer.weight.data
        target_dtype, target_device = base.dtype, base.device

        # the arithmetic is done in float32 and cast back afterwards
        w_orig = base.float()
        w_up = lora.up.weight.data.float()
        w_down = lora.down.weight.data.float()

        if lora.network_alpha is not None:
            w_up = w_up * lora.network_alpha / lora.rank

        delta = torch.bmm(w_up[None, :], w_down[None, :])[0]
        fused_weight = w_orig + (lora_scale * delta)

        if safe_fusing and torch.isnan(fused_weight).any().item():
            raise ValueError(
                "This LoRA weight seems to be broken. "
                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
                "LoRA weights will not be fused."
            )

        self.regular_linear_layer.weight.data = fused_weight.to(device=target_device, dtype=target_dtype)

        # we can drop the lora layer now
        self.lora_linear_layer = None

        # offload the up and down matrices to CPU to not blow the memory
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()
        self.lora_scale = lora_scale

    def _unfuse_lora(self):
        """Subtract the previously fused LoRA delta from the wrapped weight, if one was stored."""
        if getattr(self, "w_up", None) is None or getattr(self, "w_down", None) is None:
            return

        fused_weight = self.regular_linear_layer.weight.data
        target_dtype, target_device = fused_weight.dtype, fused_weight.device

        w_up = self.w_up.to(device=target_device).float()
        w_down = self.w_down.to(target_device).float()

        delta = torch.bmm(w_up[None, :], w_down[None, :])[0]
        unfused_weight = fused_weight.float() - (self.lora_scale * delta)
        self.regular_linear_layer.weight.data = unfused_weight.to(device=target_device, dtype=target_dtype)

        self.w_up = None
        self.w_down = None

    def forward(self, input):
        # a missing scale is replaced by 1.0 and stored back on the module
        if self.lora_scale is None:
            self.lora_scale = 1.0

        base_out = self.regular_linear_layer(input)
        if self.lora_linear_layer is None:
            return base_out
        return base_out + (self.lora_scale * self.lora_linear_layer(input))
class LoRAConv2dLayer(nn.Module):
    r"""
    A convolutional layer that is used with LoRA: a `down` convolution to `rank` channels followed by a 1x1 `up`
    convolution to `out_features` channels.

    Parameters:
        in_features (`int`):
            Number of input features.
        out_features (`int`):
            Number of output features.
        rank (`int`, `optional`, defaults to 4):
            The rank of the LoRA layer.
        kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1):
            The kernel size of the convolution.
        stride (`int` or `tuple` of two `int`, `optional`, defaults to 1):
            The stride of the convolution.
        padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0):
            The padding of the convolution.
        network_alpha (`float`, `optional`, defaults to `None`):
            The value of the network alpha used for stable learning and preventing underflow. This value has the same
            meaning as the `--network_alpha` option in the kohya-ss trainer script. See
            https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 4,
        kernel_size: Union[int, Tuple[int, int]] = (1, 1),
        stride: Union[int, Tuple[int, int]] = (1, 1),
        padding: Union[int, Tuple[int, int], str] = 0,
        network_alpha: Optional[float] = None,
    ):
        super().__init__()

        deprecation_message = "Use of `LoRAConv2dLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
        deprecate("LoRAConv2dLayer", "1.0.0", deprecation_message)

        # only the down projection takes the caller's kernel size, stride and padding
        self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        # according to the official kohya_ss trainer kernel_size are always fixed for the up layer
        # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
        self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)

        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        # `up` starts at zero, so a freshly created layer outputs zeros
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply down then up convolution in the weights' dtype and return the result in the input dtype."""
        input_dtype = hidden_states.dtype
        compute_dtype = self.down.weight.dtype

        lora_out = self.up(self.down(hidden_states.to(compute_dtype)))

        if self.network_alpha is not None:
            lora_out *= self.network_alpha / self.rank

        return lora_out.to(input_dtype)
deprecate("set_lora_layer", "1.0.0", deprecation_message) self.lora_layer = lora_layer def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False): if self.lora_layer is None: return dtype, device = self.weight.data.dtype, self.weight.data.device w_orig = self.weight.data.float() w_up = self.lora_layer.up.weight.data.float() w_down = self.lora_layer.down.weight.data.float() if self.lora_layer.network_alpha is not None: w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1)) fusion = fusion.reshape((w_orig.shape)) fused_weight = w_orig + (lora_scale * fusion) if safe_fusing and torch.isnan(fused_weight).any().item(): raise ValueError( "This LoRA weight seems to be broken. " f"Encountered NaN values when trying to fuse LoRA weights for {self}." "LoRA weights will not be fused." ) self.weight.data = fused_weight.to(device=device, dtype=dtype) # we can drop the lora layer now self.lora_layer = None # offload the up and down matrices to CPU to not blow the memory self.w_up = w_up.cpu() self.w_down = w_down.cpu() self._lora_scale = lora_scale def _unfuse_lora(self): if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None): return fused_weight = self.weight.data dtype, device = fused_weight.data.dtype, fused_weight.data.device self.w_up = self.w_up.to(device=device).float() self.w_down = self.w_down.to(device).float() fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1)) fusion = fusion.reshape((fused_weight.shape)) unfused_weight = fused_weight.float() - (self._lora_scale * fusion) self.weight.data = unfused_weight.to(device=device, dtype=dtype) self.w_up = None self.w_down = None def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: if self.padding_mode != "zeros": hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode) padding = (0, 
class LoRACompatibleLinear(nn.Linear):
    """
    A Linear layer that can be used with LoRA.

    An optional `lora_layer` contributes `scale * lora_layer(hidden_states)` on top of the regular linear output;
    it can be folded into the weight with `_fuse_lora` and removed again with `_unfuse_lora`.
    """

    def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
        deprecation_message = "Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
        deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)

        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
        deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
        deprecate("set_lora_layer", "1.0.0", deprecation_message)
        self.lora_layer = lora_layer

    def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
        """Add `lora_scale * (up @ down)` to this layer's weight and drop the LoRA layer."""
        lora = self.lora_layer
        if lora is None:
            return

        target_dtype, target_device = self.weight.data.dtype, self.weight.data.device

        # the arithmetic is done in float32 and cast back afterwards
        w_orig = self.weight.data.float()
        w_up = lora.up.weight.data.float()
        w_down = lora.down.weight.data.float()

        if lora.network_alpha is not None:
            w_up = w_up * lora.network_alpha / lora.rank

        delta = torch.bmm(w_up[None, :], w_down[None, :])[0]
        fused_weight = w_orig + (lora_scale * delta)

        if safe_fusing and torch.isnan(fused_weight).any().item():
            raise ValueError(
                "This LoRA weight seems to be broken. "
                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
                "LoRA weights will not be fused."
            )

        self.weight.data = fused_weight.to(device=target_device, dtype=target_dtype)

        # we can drop the lora layer now
        self.lora_layer = None

        # offload the up and down matrices to CPU to not blow the memory
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()
        self._lora_scale = lora_scale

    def _unfuse_lora(self):
        """Subtract the previously fused LoRA delta from the weight, if one was stored."""
        if getattr(self, "w_up", None) is None or getattr(self, "w_down", None) is None:
            return

        fused_weight = self.weight.data
        target_dtype, target_device = fused_weight.dtype, fused_weight.device

        w_up = self.w_up.to(device=target_device).float()
        w_down = self.w_down.to(target_device).float()

        delta = torch.bmm(w_up[None, :], w_down[None, :])[0]
        unfused_weight = fused_weight.float() - (self._lora_scale * delta)
        self.weight.data = unfused_weight.to(device=target_device, dtype=target_dtype)

        self.w_up = None
        self.w_down = None

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        base_out = super().forward(hidden_states)
        if self.lora_layer is None:
            return base_out
        return base_out + (scale * self.lora_layer(hidden_states))
def _fetch_remapped_cls_from_config(config, old_class):
    """Return the class that should replace `old_class` for the given config, or `old_class` itself.

    The replacement is looked up in `_CLASS_REMAPPING_DICT` by the name of `old_class` and then by
    `config["norm_type"]`.

    Args:
        config: Mapping that is indexed with `"norm_type"` when `old_class` has a remapping entry.
        old_class: The class the caller originally asked for.

    Returns:
        The remapped class taken from the top-level package of this module, or `old_class` when no remapping
        applies.
    """
    previous_class_name = old_class.__name__

    # A class without an entry has nothing to remap. The previous chained `.get(...).get(...)` raised
    # `AttributeError` on `None` in that case.
    remapping = _CLASS_REMAPPING_DICT.get(previous_class_name)
    if not remapping:
        return old_class

    remapped_class_name = remapping.get(config["norm_type"], None)
    if not remapped_class_name:
        return old_class

    # Details:
    # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
    # resolve the replacement by name on the top-level package this module belongs to
    diffusers_library = importlib.import_module(__name__.split(".")[0])
    remapped_class = getattr(diffusers_library, remapped_class_name)
    logger.info(
        f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
        f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
        " DOESN'T affect the final results."
    )
    return remapped_class
" ) def load_model_dict_into_meta( model, state_dict: OrderedDict, device: Optional[Union[str, torch.device]] = None, dtype: Optional[Union[str, torch.dtype]] = None, model_name_or_path: Optional[str] = None, ) -> List[str]: device = device or torch.device("cpu") dtype = dtype or torch.float32 accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) unexpected_keys = [] empty_state_dict = model.state_dict() for param_name, param in state_dict.items(): if param_name not in empty_state_dict: unexpected_keys.append(param_name) continue if empty_state_dict[param_name].shape != param.shape: model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" raise ValueError( f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." ) if accepts_dtype: set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) else: set_module_tensor_to_device(model, param_name, device, value=param) return unexpected_keys def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]: # Convert old format to new format if needed from a PyTorch state_dict # copy state_dict so _load_from_state_dict can modify it state_dict = state_dict.copy() error_msgs = [] # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants # so we need to apply the function recursively. 
def load(module: torch.nn.Module, prefix: str = ""): args = (state_dict, prefix, {}, True, [], [], error_msgs) module._load_from_state_dict(*args) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + ".") load(model_to_load) return error_msgs def _fetch_index_file( is_local, pretrained_model_name_or_path, subfolder, use_safetensors, cache_dir, variant, force_download, proxies, local_files_only, token, revision, user_agent, commit_hash, ): if is_local: index_file = Path( pretrained_model_name_or_path, subfolder or "", _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant), ) else: index_file_in_repo = Path( subfolder or "", _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant), ).as_posix() try: index_file = _get_model_file( pretrained_model_name_or_path, weights_name=index_file_in_repo, cache_dir=cache_dir, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision, subfolder=None, user_agent=user_agent, commit_hash=commit_hash, ) index_file = Path(index_file) except (EntryNotFoundError, EnvironmentError): index_file = None return index_file ================================================ FILE: src/diffusers/models/modeling_flax_pytorch_utils.py ================================================ # coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""PyTorch - Flax general utilities.""" import re import jax.numpy as jnp from flax.traverse_util import flatten_dict, unflatten_dict from jax.random import PRNGKey from ..utils import logging logger = logging.get_logger(__name__) def rename_key(key): regex = r"\w+[.]\d+" pats = re.findall(regex, key) for pat in pats: key = key.replace(pat, "_".join(pat.split("."))) return key ##################### # PyTorch => Flax # ##################### # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" # conv norm or layer norm renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) # rename attention layers if len(pt_tuple_key) > 1: for rename_from, rename_to in ( ("to_out_0", "proj_attn"), ("to_k", "key"), ("to_v", "value"), ("to_q", "query"), ): if pt_tuple_key[-2] == rename_from: weight_name = pt_tuple_key[-1] weight_name = "kernel" if weight_name == "weight" else weight_name renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name) if renamed_pt_tuple_key in random_flax_state_dict: assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape return renamed_pt_tuple_key, pt_tensor.T if ( any("norm" in str_ for str_ in pt_tuple_key) and (pt_tuple_key[-1] == "bias") and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict) and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict) ): renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) return renamed_pt_tuple_key, pt_tensor elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) return 
renamed_pt_tuple_key, pt_tensor # embedding if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict: pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) return renamed_pt_tuple_key, pt_tensor # conv layer renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: pt_tensor = pt_tensor.transpose(2, 3, 1, 0) return renamed_pt_tuple_key, pt_tensor # linear layer renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) if pt_tuple_key[-1] == "weight": pt_tensor = pt_tensor.T return renamed_pt_tuple_key, pt_tensor # old PyTorch layer norm weight renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) if pt_tuple_key[-1] == "gamma": return renamed_pt_tuple_key, pt_tensor # old PyTorch layer norm bias renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) if pt_tuple_key[-1] == "beta": return renamed_pt_tuple_key, pt_tensor return pt_tuple_key, pt_tensor def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): # Step 1: Convert pytorch tensor to numpy pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} # Step 2: Since the model is stateless, get random Flax params random_flax_params = flax_model.init_weights(PRNGKey(init_key)) random_flax_state_dict = flatten_dict(random_flax_params) flax_state_dict = {} # Need to change some parameters name to match Flax names for pt_key, pt_tensor in pt_state_dict.items(): renamed_pt_key = rename_key(pt_key) pt_tuple_key = tuple(renamed_pt_key.split(".")) # Correctly rename weight parameters flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict) if flax_key in random_flax_state_dict: if flax_tensor.shape != random_flax_state_dict[flax_key].shape: raise ValueError( f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." 
class FlaxModelMixin(PushToHubMixin):
    r"""
    Base class for all Flax models.

    [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
    saving models.

        - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
    """

    config_name = CONFIG_NAME
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    _flax_internal_args = ["name", "parent", "dtype"]

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        """
        return cls(config, **kwargs)

    def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
        """
        Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
        """

        # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
        def conditional_cast(param):
            if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
                param = param.astype(dtype)
            return param

        # Use the `jax.tree_util` entry points (as the rest of this file does) rather than the deprecated
        # top-level `jax.tree_map` / `jax.tree_flatten` aliases.
        if mask is None:
            return jax.tree_util.tree_map(conditional_cast, params)

        flat_params = flatten_dict(params)
        flat_mask, _ = jax.tree_util.tree_flatten(mask)

        # NOTE(review): mask leaves and parameter keys are paired purely by position, so this relies on
        # `tree_flatten(mask)` and `flatten_dict(params)` yielding the same order -- confirm for your inputs.
        for masked, key in zip(flat_mask, flat_params.keys()):
            if masked:
                param = flat_params[key]
                flat_params[key] = conditional_cast(param)

        return unflatten_dict(flat_params)

    def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
        the `params` in place.

        This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
        half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
                for params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # load model
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
        >>> params = model.to_bf16(params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util

        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> flat_params = traverse_util.flatten_dict(params)
        >>> mask = {
        ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> params = model.to_bf16(params, mask)
        ```"""
        return self._cast_floating_to(params, jnp.bfloat16, mask)

    def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
        model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
                for params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # Download model and configuration from huggingface.co
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # By default, the model params will be in fp32, to illustrate the use of this method,
        >>> # we'll first cast to fp16 and back to fp32
        >>> params = model.to_fp16(params)
        >>> # now cast back to fp32
        >>> params = model.to_fp32(params)
        ```"""
        return self._cast_floating_to(params, jnp.float32, mask)

    def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
        `params` in place.

        This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
        half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
                for params you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # load model
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # By default, the model params will be in fp32, to cast these to float16
        >>> params = model.to_fp16(params)
        >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util

        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> flat_params = traverse_util.flatten_dict(params)
        >>> mask = {
        ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> params = model.to_fp16(params, mask)
        ```"""
        return self._cast_floating_to(params, jnp.float16, mask)

    def init_weights(self, rng: jax.Array) -> Dict:
        """Return freshly initialised parameters; must be implemented by every concrete model."""
        raise NotImplementedError(f"init_weights method has to be implemented for {self}")

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        dtype: jnp.dtype = jnp.float32,
        *model_args,
        **kwargs,
    ):
        r"""
        Instantiate a pretrained Flax model from a pretrained model configuration.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
                      hosted on the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      using [`~FlaxModelMixin.save_pretrained`].
            dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
                The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
                `jax.numpy.bfloat16` (on TPUs).

                This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
                specified, all the computation will be performed with the given `dtype`.

                <Tip>

                This only specifies the dtype of the *computation* and does not influence the dtype of model
                parameters.

                If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
                [`~FlaxModelMixin.to_bf16`].

                </Tip>

            model_args (sequence of positional arguments, *optional*):
                All remaining positional arguments are passed to the underlying model's `__init__` method.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            from_pt (`bool`, *optional*, defaults to `False`):
                Load the model weights from a PyTorch checkpoint save file.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it is loaded) and initiate the model (for
                example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
                automatically loaded:

                    - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
                      model's `__init__` method (we assume all relevant updates to the configuration have already been
                      done).
                    - If a configuration is not provided, `kwargs` are first passed to the configuration class
                      initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
                      to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
                      Remaining keys that do not correspond to any configuration attribute are passed to the underlying
                      model's `__init__` function.

        Examples:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # Download model and configuration from huggingface.co and cache.
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
        ```

        If you get the error message below, you need to finetune the weights for your downstream task:

        ```bash
        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
        ```
        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        from_pt = kwargs.pop("from_pt", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)

        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "flax",
        }

        # When a config is supplied by the caller, the remaining kwargs go straight to the model; without this
        # default `unused_kwargs` would be unbound on that path.
        unused_kwargs = kwargs

        # Load config if we don't provide one
        if config is None:
            config, unused_kwargs = cls.load_config(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                **kwargs,
            )

        model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)

        # Load model
        pretrained_path_with_subfolder = (
            pretrained_model_name_or_path
            if subfolder is None
            else os.path.join(pretrained_model_name_or_path, subfolder)
        )
        # Name of the weights file that is actually requested from the Hub.
        weights_filename = FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME

        if os.path.isdir(pretrained_path_with_subfolder):
            if from_pt:
                if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
                    raise EnvironmentError(
                        f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
                    )
                model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
                # Load from a Flax checkpoint
                model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
            # Check if pytorch weights exist instead
            elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
                raise EnvironmentError(
                    f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
                    " using `from_pt=True`."
                )
            else:
                raise EnvironmentError(
                    f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                    f"{pretrained_path_with_subfolder}."
                )
        else:
            try:
                model_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=weights_filename,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    subfolder=subfolder,
                    revision=revision,
                )

            except RepositoryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                    "token having permission to this repo with `token` or log in with `huggingface-cli "
                    "login`."
                )
            except RevisionNotFoundError:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                    "this model name. Check the model page at "
                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                )
            except EntryNotFoundError:
                # Name the file that was actually requested (the PyTorch file when `from_pt=True`).
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} does not appear to have a file named {weights_filename}."
                )
            except HTTPError as err:
                raise EnvironmentError(
                    f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
                    f"{err}"
                )
            except ValueError:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                    f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
                    " internet connection or see how to run the library in offline mode at"
                    " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
                )
            except EnvironmentError:
                raise EnvironmentError(
                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
                )

        if from_pt:
            if is_torch_available():
                from .modeling_utils import load_state_dict
            else:
                raise EnvironmentError(
                    "Can't load the model in PyTorch format because PyTorch is not installed. "
                    "Please, install PyTorch or use native Flax weights."
                )

            # Step 1: Get the pytorch file
            pytorch_model_file = load_state_dict(model_file)

            # Step 2: Convert the weights
            state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
        else:
            try:
                with open(model_file, "rb") as state_f:
                    state = from_bytes(cls, state_f.read())
            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                # Re-open as text to distinguish a git-lfs pointer file from a genuinely corrupt checkpoint.
                try:
                    with open(model_file) as f:
                        if f.read().startswith("version"):
                            raise OSError(
                                "You seem to have cloned a repository without having git-lfs installed. Please"
                                " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                                " folder you cloned."
                            )
                        else:
                            raise ValueError from e
                except (UnicodeDecodeError, ValueError):
                    raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
            # make sure all arrays are stored as jnp.ndarray
            # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
            # https://github.com/google/flax/issues/1261
        state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)

        # flatten dicts
        state = flatten_dict(state)

        # Only shapes are needed here, so `eval_shape` is used instead of really initialising the weights.
        params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
        shape_state = flatten_dict(unfreeze(params_shape_tree))
        required_params = set(shape_state.keys())

        missing_keys = required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - required_params

        if missing_keys:
            logger.warning(
                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
                "Make sure to call model.init_weights to initialize the missing weights."
            )
            cls._missing_keys = missing_keys

        for key in state.keys():
            if key in shape_state and state[key].shape != shape_state[key].shape:
                raise ValueError(
                    f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                    f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
                )

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
                " with another architecture."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
                " training."
            )

        return model, unflatten_dict(state)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        params: Union[Dict, FrozenDict],
        is_main_process: bool = True,
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save a model and its configuration file to a directory so that it can be reloaded using the
        [`~FlaxModelMixin.from_pretrained`] class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save a model and its configuration file to. Will be created if it doesn't exist.
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        os.makedirs(save_directory, exist_ok=True)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            private = kwargs.pop("private", False)
            create_pr = kwargs.pop("create_pr", False)
            token = kwargs.pop("token", None)
            # `os.fspath` keeps this working when `save_directory` is a `PathLike` rather than a `str`.
            repo_id = kwargs.pop("repo_id", os.fspath(save_directory).split(os.path.sep)[-1])
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

        model_to_save = self

        # Attach architecture to the config
        # Save the config
        if is_main_process:
            model_to_save.save_config(save_directory)

        # save model
        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
        with open(output_model_file, "wb") as f:
            model_bytes = to_bytes(params)
            f.write(model_bytes)

        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )
#####################
# Flax => PyTorch #
#####################


# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
    """
    Deserialize the Flax checkpoint stored at `model_file` and load it into `pt_model`.

    Raises `OSError` when the file looks like a git-lfs pointer (its text starts with `"version"`) and
    `EnvironmentError` when the bytes cannot be turned into a Flax state for any other unpickling reason.
    """
    try:
        with open(model_file, "rb") as flax_state_f:
            serialized = flax_state_f.read()
            flax_state = from_bytes(None, serialized)
    except UnpicklingError as e:
        # Look at the file as text to tell an lfs pointer apart from a corrupt checkpoint.
        try:
            with open(model_file) as f:
                looks_like_lfs_pointer = f.read().startswith("version")
                if not looks_like_lfs_pointer:
                    raise ValueError from e
                raise OSError(
                    "You seem to have cloned a repository without having git-lfs installed. Please"
                    " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                    " folder you cloned."
                )
        except (UnicodeDecodeError, ValueError):
            raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")

    return load_flax_weights_in_pytorch_model(pt_model, flax_state)


def load_flax_weights_in_pytorch_model(pt_model, flax_state):
    """
    Load flax checkpoints in a PyTorch model

    Flax parameter names are translated to PyTorch names (`kernel`/`scale` become `weight`, `_<digit>` becomes
    `.<digit>` unless the name has a `time_embedding` component), kernels are transposed back to the PyTorch
    layout, and the result is loaded with `pt_model.load_state_dict`. Unused and missing weights are logged as
    warnings. Returns `pt_model`.
    """
    try:
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
            " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
            " instructions."
        )
        raise

    # check if we have bf16 weights
    bf16_flags = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
    if any(bf16_flags):
        # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16
        # and bf16 is not fully supported in PT yet.
        logger.warning(
            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
            "before loading those in PyTorch model."
        )

        def upcast_bf16(params):
            if params.dtype == jnp.bfloat16:
                return params.astype(np.float32)
            return params

        flax_state = jax.tree_util.tree_map(upcast_bf16, flax_state)

    pt_model.base_model_prefix = ""

    flat_flax_state = flatten_dict(flax_state, sep=".")
    pt_state = pt_model.state_dict()

    # keep track of unexpected & missing keys
    unexpected_keys = []
    missing_keys = set(pt_state.keys())

    for flax_name, flax_tensor in flat_flax_state.items():
        name_parts = flax_name.split(".")
        leaf = name_parts[-1]

        if leaf == "kernel":
            name_parts = name_parts[:-1] + ["weight"]
            if flax_tensor.ndim == 4:
                # conv kernel: permute back with axes (3, 2, 0, 1)
                flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
            else:
                flax_tensor = flax_tensor.T
        elif leaf == "scale":
            name_parts = name_parts[:-1] + ["weight"]

        if "time_embedding" not in name_parts:
            # turn every `_<digit>` into `.<digit>`
            for idx, segment in enumerate(name_parts):
                for digit in "0123456789":
                    segment = segment.replace("_" + digit, "." + digit)
                name_parts[idx] = segment

        pt_name = ".".join(name_parts)

        if pt_name not in pt_state:
            # weight is not expected by PyTorch model
            unexpected_keys.append(pt_name)
            continue

        if flax_tensor.shape != pt_state[pt_name].shape:
            raise ValueError(
                f"Flax checkpoint seems to be incorrect. Weight {flax_name} was expected "
                f"to be of shape {pt_state[pt_name].shape}, but is {flax_tensor.shape}."
            )

        # add weight to pytorch dict
        if not isinstance(flax_tensor, np.ndarray):
            flax_tensor = np.asarray(flax_tensor)
        pt_state[pt_name] = torch.from_numpy(flax_tensor)
        # remove from missing keys
        missing_keys.remove(pt_name)

    pt_model.load_state_dict(pt_state)

    # re-transform missing_keys to list
    missing_keys = list(missing_keys)

    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the Flax model were not used when initializing the PyTorch model"
            f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
            " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
            f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
            " FlaxBertForSequenceClassification model)."
        )
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )

    return pt_model
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
    """Return the device of the first tensor found on `parameter`.

    Registered parameters are consulted first, then registered buffers. When the
    module has neither, tensor-valued entries of the modules' ``__dict__`` are
    used instead (kept for torch.nn.DataParallel compatibility in PyTorch 1.5).

    Raises:
        StopIteration: if that last lookup finds no tensor either.
    """
    for tensor in itertools.chain(parameter.parameters(), parameter.buffers()):
        # Only the very first parameter/buffer matters.
        return tensor.device

    # For torch.nn.DataParallel compatibility in PyTorch 1.5
    def _tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        return [(name, value) for name, value in module.__dict__.items() if torch.is_tensor(value)]

    members = parameter._named_members(get_members_fn=_tensor_attributes)
    _, first_tensor = next(members)
    return first_tensor.device
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
    """Return the dtype of the first tensor found on `parameter`.

    Lookup order:

    1. the first registered parameter;
    2. the first registered buffer (only when there are no parameters);
    3. the first tensor-valued entry found in the modules' ``__dict__`` (kept for
       torch.nn.DataParallel compatibility in PyTorch 1.5).

    Step 3 used to live in an ``except StopIteration`` handler that could never
    fire, because the parameters/buffers were materialised with ``tuple()``
    rather than pulled with ``next()``. It now runs whenever steps 1 and 2 find
    nothing.

    Returns:
        The dtype that was found, or ``None`` when the module holds no tensor at
        all (the value the function has always produced in that case).
    """
    for tensor in itertools.chain(parameter.parameters(), parameter.buffers()):
        # Parameters are chained before buffers, so the first hit follows the order above.
        return tensor.dtype

    # For torch.nn.DataParallel compatibility in PyTorch 1.5
    def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
    first_tuple = next(gen, None)
    if first_tuple is None:
        return None
    return first_tuple[1].dtype
@property
def is_gradient_checkpointing(self) -> bool:
    """
    `True` as soon as one module yielded by `self.modules()` carries a truthy `gradient_checkpointing` attribute,
    `False` otherwise.
    """
    for submodule in self.modules():
        if getattr(submodule, "gradient_checkpointing", False):
            return True
    return False


def enable_gradient_checkpointing(self) -> None:
    """
    Turn gradient checkpointing on (other frameworks may call this *activation checkpointing* or *checkpoint
    activations*).

    Raises a `ValueError` when the class sets `_supports_gradient_checkpointing` to a falsy value.
    """
    if not self._supports_gradient_checkpointing:
        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

    toggle = self._set_gradient_checkpointing
    self.apply(lambda module: toggle(module, value=True))


def disable_gradient_checkpointing(self) -> None:
    """
    Turn gradient checkpointing off (other frameworks may call this *activation checkpointing* or *checkpoint
    activations*). Does nothing when the class does not support gradient checkpointing.
    """
    if not self._supports_gradient_checkpointing:
        return

    toggle = self._set_gradient_checkpointing
    self.apply(lambda module: toggle(module, value=False))


def set_use_npu_flash_attention(self, valid: bool) -> None:
    r"""
    Forward `valid` to every descendant module that exposes a `set_use_npu_flash_attention` method.
    """

    def _propagate(node: torch.nn.Module):
        # Parent first, then its children, depth-first.
        if hasattr(node, "set_use_npu_flash_attention"):
            node.set_use_npu_flash_attention(valid)

        for sub in node.children():
            _propagate(sub)

    for top_level in self.children():
        if isinstance(top_level, torch.nn.Module):
            _propagate(top_level)


def enable_npu_flash_attention(self) -> None:
    r"""
    Switch npu flash attention (from torch_npu) on for all sub-modules that support the switch.
    """
    self.set_use_npu_flash_attention(True)


def disable_npu_flash_attention(self) -> None:
    r"""
    Switch npu flash attention (from torch_npu) off for all sub-modules that support the switch.
    """
    self.set_use_npu_flash_attention(False)


def set_use_memory_efficient_attention_xformers(
    self, valid: bool, attention_op: Optional[Callable] = None
) -> None:
    """
    Forward `valid` and `attention_op` to every descendant module that exposes a
    `set_use_memory_efficient_attention_xformers` method.
    """

    def _propagate(node: torch.nn.Module):
        # Parent first, then its children, depth-first.
        if hasattr(node, "set_use_memory_efficient_attention_xformers"):
            node.set_use_memory_efficient_attention_xformers(valid, attention_op)

        for sub in node.children():
            _propagate(sub)

    for top_level in self.children():
        if isinstance(top_level, torch.nn.Module):
            _propagate(top_level)


def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
    r"""
    Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).

    With this option on you should see lower GPU memory usage and possibly faster inference; a speed up during
    training is not guaranteed.

    ⚠️ If memory efficient attention and sliced attention are both enabled, memory efficient attention takes
    precedence.

    Parameters:
        attention_op (`Callable`, *optional*):
            Replaces the default `None` operator passed as the `op` argument to xFormers'
            [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
            function.

    Examples:

    ```py
    >>> import torch
    >>> from diffusers import UNet2DConditionModel
    >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

    >>> model = UNet2DConditionModel.from_pretrained(
    ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
    ... )
    >>> model = model.to("cuda")
    >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
    ```
    """
    self.set_use_memory_efficient_attention_xformers(True, attention_op)
def disable_xformers_memory_efficient_attention(self) -> None:
    r"""
    Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
    """
    self.set_use_memory_efficient_attention_xformers(False)


def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
    save_function: Optional[Callable] = None,
    safe_serialization: bool = True,
    variant: Optional[str] = None,
    max_shard_size: Union[int, str] = "10GB",
    push_to_hub: bool = False,
    **kwargs,
):
    """
    Save a model and its configuration file to a directory so that it can be reloaded using the
    [`~models.ModelMixin.from_pretrained`] class method.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to save a model and its configuration file to. Will be created if it doesn't exist.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the process calling this is the main process or not. Useful during distributed training and you
            need to call this function on all processes. In this case, set `is_main_process=True` only on the main
            process to avoid race conditions. Only the main process writes the config and removes stale shard
            files; the weight files themselves are written by every caller.
        save_function (`Callable`):
            The function to use to save the state dictionary. Useful during distributed training when you need to
            replace `torch.save` with another method. Can be configured with the environment variable
            `DIFFUSERS_SAVE_MODE`. NOTE(review): the body below never reads this argument.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        variant (`str`, *optional*):
            If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
        max_shard_size (`int` or `str`, defaults to `"10GB"`):
            The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
            lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
            If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
            period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
            This is to establish a common default size for this argument across different libraries in the Hugging
            Face ecosystem (`transformers`, and `accelerate`, for example).
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
            repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
            namespace).
        kwargs (`Dict[str, Any]`, *optional*):
            Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. The keys
            read here are `commit_message`, `private`, `create_pr`, `token` and `repo_id`.

    Raises:
        ValueError: if the (variant-decorated) weights file name does not consist of two or three dot-separated
            parts.
    """
    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
    weights_name = _add_variant(weights_name, variant)
    weight_name_split = weights_name.split(".")
    if len(weight_name_split) in [2, 3]:
        # "{suffix}" is the placeholder the sharding helper fills in for each shard.
        weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
    else:
        raise ValueError(f"Invalid {weights_name} provided.")

    os.makedirs(save_directory, exist_ok=True)

    if push_to_hub:
        commit_message = kwargs.pop("commit_message", None)
        private = kwargs.pop("private", False)
        create_pr = kwargs.pop("create_pr", False)
        token = kwargs.pop("token", None)
        # os.fspath() keeps `str` inputs untouched and makes `os.PathLike` inputs splittable.
        repo_id = kwargs.pop("repo_id", os.fspath(save_directory).split(os.path.sep)[-1])
        repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

    # Only save the model itself if we are using distributed training
    model_to_save = self

    # Save the config
    if is_main_process:
        model_to_save.save_config(save_directory)

    # Split the state dict into shards no larger than `max_shard_size`
    state_dict = model_to_save.state_dict()
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
    )

    # Clean the folder from a previous save
    if is_main_process:
        # Loop-invariant: the weights file stem without extension and without shard suffix.
        weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
        weights_without_ext = weights_without_ext.replace("{suffix}", "")

        for filename in os.listdir(save_directory):
            if filename in state_dict_split.filename_to_tensors.keys():
                # About to be overwritten below, nothing to clean.
                continue
            full_filename = os.path.join(save_directory, filename)
            if not os.path.isfile(full_filename):
                continue

            filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")

            # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
            if (
                filename.startswith(weights_without_ext)
                and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
            ):
                os.remove(full_filename)

    for filename, tensors in state_dict_split.filename_to_tensors.items():
        shard = {tensor: state_dict[tensor] for tensor in tensors}
        filepath = os.path.join(save_directory, filename)
        if safe_serialization:
            # At some point we will need to deal better with save_function (used for TPU and other distributed
            # joyfulness), but for now this enough.
            safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
        else:
            torch.save(shard, filepath)

    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
        save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
        # Save the index as well
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        logger.info(
            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
            f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
            f"index located at {save_index_file}."
        )
    else:
        path_to_weights = os.path.join(save_directory, weights_name)
        logger.info(f"Model weights saved in {path_to_weights}")

    if push_to_hub:
        # Create a new empty model card and eventually tag it
        model_card = load_or_create_model_card(repo_id, token=token)
        model_card = populate_model_card(model_card)
        model_card.save(Path(save_directory, "README.md").as_posix())

        self._upload_folder(
            save_directory,
            repo_id,
            token=token,
            commit_message=commit_message,
            create_pr=create_pr,
        )
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
    r"""
    Instantiate a pretrained PyTorch model from a pretrained model configuration.

    The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
    train the model, set it back in training mode with `model.train()`.

    Parameters:
        pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
            Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`~ModelMixin.save_pretrained`].

        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        torch_dtype (`torch.dtype`, *optional*):
            Override the default `torch.dtype` and load the model with another dtype. Must be a `torch.dtype`
            instance; any other non-`None` value raises a `ValueError` before anything is loaded.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Forwarded to `_load_pretrained_model` on the `low_cpu_mem_usage=False` path.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        output_loading_info (`bool`, *optional*, defaults to `False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            Only the `low_cpu_mem_usage=False` PyTorch path fills this dictionary; the other paths return empty
            lists.
        local_files_only(`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        from_flax (`bool`, *optional*, defaults to `False`):
            Load the model weights from a Flax checkpoint save file.
        subfolder (`str`, *optional*, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        mirror (`str`, *optional*):
            Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
            guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
            information.
        device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
            A map that specifies where each submodule should go. It doesn't need to be defined for each
            parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
            same device. Defaults to `None`, meaning that the model will be loaded on CPU.

            Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
            more information about each option see [designing a device
            map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
        max_memory (`Dict`, *optional*):
            A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
            each GPU and the available CPU RAM if unset.
        offload_folder (`str` or `os.PathLike`, *optional*):
            The path to offload weights if `device_map` contains the value `"disk"`.
        offload_state_dict (`bool`, *optional*):
            If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
            the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
            when there is some disk offload.
        low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
            Speed up model loading only loading the pretrained weights and not initializing the weights. This also
            tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
            Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
            argument to `True` will raise an error.
        variant (`str`, *optional*):
            Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
            loading `from_flax`.
        use_safetensors (`bool`, *optional*, defaults to `None`):
            If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
            `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
            weights. If set to `False`, `safetensors` weights are not loaded.

    To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
    `huggingface-cli login`. You can also activate the special
    ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
    firewalled environment.

    Example:

    ```py
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
    ```

    If you get the error message below, you need to finetune the weights for your downstream task:

    ```bash
    Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
    - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
    ```
    """
    cache_dir = kwargs.pop("cache_dir", None)
    ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
    force_download = kwargs.pop("force_download", False)
    from_flax = kwargs.pop("from_flax", False)
    proxies = kwargs.pop("proxies", None)
    output_loading_info = kwargs.pop("output_loading_info", False)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)
    torch_dtype = kwargs.pop("torch_dtype", None)
    subfolder = kwargs.pop("subfolder", None)
    device_map = kwargs.pop("device_map", None)
    max_memory = kwargs.pop("max_memory", None)
    offload_folder = kwargs.pop("offload_folder", None)
    offload_state_dict = kwargs.pop("offload_state_dict", False)
    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
    variant = kwargs.pop("variant", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    # Reject an invalid dtype up front instead of after the weights have been fetched and loaded.
    if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
        raise ValueError(
            f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
        )

    # `use_safetensors=None` means "prefer safetensors, but fall back to pickle weights if they are missing".
    allow_pickle = False
    if use_safetensors is None:
        use_safetensors = True
        allow_pickle = True

    if low_cpu_mem_usage and not is_accelerate_available():
        low_cpu_mem_usage = False
        logger.warning(
            "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
            " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
            " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
            " install accelerate\n```\n."
        )

    if device_map is not None and not is_accelerate_available():
        raise NotImplementedError(
            "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
            " `device_map=None`. You can install accelerate with `pip install accelerate`."
        )

    # Check if we can handle device_map and dispatching the weights
    if device_map is not None and not is_torch_version(">=", "1.9.0"):
        raise NotImplementedError(
            "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
            " `device_map=None`."
        )

    if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
        raise NotImplementedError(
            "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
            " `low_cpu_mem_usage=False`."
        )

    if low_cpu_mem_usage is False and device_map is not None:
        raise ValueError(
            f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
            " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
        )

    # change device_map into a map if we passed an int, a str or a torch.device
    if isinstance(device_map, torch.device):
        device_map = {"": device_map}
    elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
        try:
            device_map = {"": torch.device(device_map)}
        except RuntimeError as err:
            raise ValueError(
                "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
                f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
            ) from err
    elif isinstance(device_map, int):
        if device_map < 0:
            raise ValueError(
                "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
            )
        else:
            device_map = {"": device_map}

    if device_map is not None:
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
        elif not low_cpu_mem_usage:
            raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")

    if low_cpu_mem_usage:
        if device_map is not None and not is_torch_version(">=", "1.10"):
            # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
            raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")

    # Load config if we don't provide a configuration
    config_path = pretrained_model_name_or_path

    user_agent = {
        "diffusers": __version__,
        "file_type": "model",
        "framework": "pytorch",
    }

    # load config
    config, unused_kwargs, commit_hash = cls.load_config(
        config_path,
        cache_dir=cache_dir,
        return_unused_kwargs=True,
        return_commit_hash=True,
        force_download=force_download,
        proxies=proxies,
        local_files_only=local_files_only,
        token=token,
        revision=revision,
        subfolder=subfolder,
        user_agent=user_agent,
        **kwargs,
    )

    # Determine if we're loading from a directory of sharded checkpoints.
    is_sharded = False
    index_file = None
    is_local = os.path.isdir(pretrained_model_name_or_path)
    index_file = _fetch_index_file(
        is_local=is_local,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        subfolder=subfolder or "",
        use_safetensors=use_safetensors,
        cache_dir=cache_dir,
        variant=variant,
        force_download=force_download,
        proxies=proxies,
        local_files_only=local_files_only,
        token=token,
        revision=revision,
        user_agent=user_agent,
        commit_hash=commit_hash,
    )
    if index_file is not None and index_file.is_file():
        is_sharded = True

    if is_sharded and from_flax:
        raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")

    # load model
    model_file = None
    if from_flax:
        model_file = _get_model_file(
            pretrained_model_name_or_path,
            weights_name=FLAX_WEIGHTS_NAME,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            commit_hash=commit_hash,
        )
        model = cls.from_config(config, **unused_kwargs)

        # Convert the weights
        from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

        model = load_flax_checkpoint_in_pytorch_model(model, model_file)

        # The Flax loader returns only the model, so there is nothing to report here. Defining the
        # dict keeps `output_loading_info=True` from failing on an unbound name.
        loading_info = {
            "missing_keys": [],
            "unexpected_keys": [],
            "mismatched_keys": [],
            "error_msgs": [],
        }
    else:
        if is_sharded:
            # Return values are not used below; presumably the call resolves/downloads the shard files.
            sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
                pretrained_model_name_or_path,
                index_file,
                cache_dir=cache_dir,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                user_agent=user_agent,
                revision=revision,
                subfolder=subfolder or "",
            )

        elif use_safetensors and not is_sharded:
            try:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                    commit_hash=commit_hash,
                )

            except IOError as e:
                logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                if not allow_pickle:
                    raise
                # `allow_pickle` is internal; the caller-facing switch is `use_safetensors`.
                logger.warning(
                    "Defaulting to unsafe serialization. Pass `use_safetensors=True` to raise an error instead."
                )

        if model_file is None and not is_sharded:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=_add_variant(WEIGHTS_NAME, variant),
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
                commit_hash=commit_hash,
            )

        if low_cpu_mem_usage:
            # Instantiate model with empty weights
            with accelerate.init_empty_weights():
                model = cls.from_config(config, **unused_kwargs)

            # if device_map is None, load the state dict and move the params from meta device to the cpu
            if device_map is None and not is_sharded:
                param_device = "cpu"
                state_dict = load_state_dict(model_file, variant=variant)
                model._convert_deprecated_attention_blocks(state_dict)
                # move the params from meta device to cpu
                missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                if len(missing_keys) > 0:
                    raise ValueError(
                        f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
                        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                        " those weights or else make sure your checkpoint file is correct."
                    )

                unexpected_keys = load_model_dict_into_meta(
                    model,
                    state_dict,
                    device=param_device,
                    dtype=torch_dtype,
                    model_name_or_path=pretrained_model_name_or_path,
                )

                if cls._keys_to_ignore_on_load_unexpected is not None:
                    for pat in cls._keys_to_ignore_on_load_unexpected:
                        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                if len(unexpected_keys) > 0:
                    # Plain comma-separated list (the joined string is no longer wrapped in a list literal).
                    logger.warning(
                        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
                    )

            else:  # else let accelerate handle loading and dispatching.
                # Load weights and dispatch according to the device_map
                # by default the device_map is None and the weights are loaded on the CPU
                force_hook = True
                device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
                if device_map is None and is_sharded:
                    # we load the parameters on the cpu
                    device_map = {"": "cpu"}
                    force_hook = False

                # Same arguments for the first attempt and for the retry with deprecated attention names.
                checkpoint = model_file if not is_sharded else index_file
                dispatch_kwargs = {
                    "max_memory": max_memory,
                    "offload_folder": offload_folder,
                    "offload_state_dict": offload_state_dict,
                    "dtype": torch_dtype,
                    "force_hooks": force_hook,
                    "strict": True,
                }
                try:
                    accelerate.load_checkpoint_and_dispatch(model, checkpoint, device_map, **dispatch_kwargs)
                except AttributeError as e:
                    # When using accelerate loading, we do not have the ability to load the state
                    # dict and rename the weight names manually. Additionally, accelerate skips
                    # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
                    # (which look like they should be private variables?), so we can't use the standard hooks
                    # to rename parameters on load. We need to mimic the original weight names so the correct
                    # attributes are available. After we have loaded the weights, we convert the deprecated
                    # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
                    # the weights so we don't have to do this again.

                    if "'Attention' object has no attribute" in str(e):
                        logger.warning(
                            f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
                            " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
                            " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
                            " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
                            " please also re-upload it or open a PR on the original repository."
                        )
                        model._temp_convert_self_to_deprecated_attention_blocks()
                        accelerate.load_checkpoint_and_dispatch(model, checkpoint, device_map, **dispatch_kwargs)
                        model._undo_temp_convert_self_to_deprecated_attention_blocks()
                    else:
                        # Bare raise keeps the original traceback intact.
                        raise

            loading_info = {
                "missing_keys": [],
                "unexpected_keys": [],
                "mismatched_keys": [],
                "error_msgs": [],
            }
        else:
            model = cls.from_config(config, **unused_kwargs)

            state_dict = load_state_dict(model_file, variant=variant)
            model._convert_deprecated_attention_blocks(state_dict)

            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                model,
                state_dict,
                model_file,
                pretrained_model_name_or_path,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
            )

            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "mismatched_keys": mismatched_keys,
                "error_msgs": error_msgs,
            }

    # `torch_dtype` was validated at the top, so it is either None or a real torch.dtype here.
    if torch_dtype is not None:
        model = model.to(torch_dtype)

    model.register_to_config(_name_or_path=pretrained_model_name_or_path)

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()
    if output_loading_info:
        return model, loading_info

    return model
) raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") if len(unexpected_keys) > 0: logger.warning( f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" " BertForPreTraining model).\n- This IS NOT expected if you are initializing" f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" " identical (initializing a BertForSequenceClassification model from a" " BertForSequenceClassification model)." ) else: logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") if len(missing_keys) > 0: logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" " TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) elif len(mismatched_keys) == 0: logger.info( f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" " without further training." 
@classmethod
def _get_signature_keys(cls, obj):
    """
    Split the parameters of `obj.__init__` into required and optional names.

    Args:
        obj: Object (typically a class) whose `__init__` signature is inspected.

    Returns:
        `Tuple[Set[str], Set[str]]`: The names of the parameters without a default value (with `"self"`
        removed), followed by the names of the parameters that declare a default value.
    """
    required_names = set()
    optional_names = set()

    for param_name, param in inspect.signature(obj.__init__).parameters.items():
        if param.default == inspect._empty:
            required_names.add(param_name)
        elif param.default != inspect._empty:
            optional_names.add(param_name)

    # `self` only shows up when `__init__` is looked up on the class rather than on an instance.
    required_names.discard("self")
    return required_names, optional_names
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
    """
    Get number of (trainable or non-embedding) parameters in the module.

    Args:
        only_trainable (`bool`, *optional*, defaults to `False`):
            Whether or not to return only the number of trainable parameters.
        exclude_embeddings (`bool`, *optional*, defaults to `False`):
            Whether or not to return only the number of non-embedding parameters.

    Returns:
        `int`: The number of parameters.

    Example:

    ```py
    from diffusers import UNet2DConditionModel

    model_id = "runwayml/stable-diffusion-v1-5"
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    unet.num_parameters(only_trainable=True)
    859520964
    ```
    """
    if exclude_embeddings:
        # Keep the names in a set: every parameter name is tested against this collection, so a list would
        # make the filter quadratic in the number of parameters/embeddings.
        embedding_param_names = {
            f"{name}.weight" for name, module in self.named_modules() if isinstance(module, torch.nn.Embedding)
        }
        parameters = (
            parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
        )
    else:
        parameters = self.parameters()

    # A parameter is counted when it is trainable, or unconditionally when `only_trainable` is off.
    return sum(p.numel() for p in parameters if p.requires_grad or not only_trainable)
def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
    """
    Temporarily expose the deprecated attention attribute names on every flagged attention block.

    Every module in the tree (including `self`) whose `_from_deprecated_attn_block` attribute is truthy gets
    `query`, `key`, `value` and `proj_attn` pointing at its current `to_q`, `to_k`, `to_v` and `to_out[0]`
    submodules, after which the new-style attributes are removed.
    """
    legacy_blocks = []

    # Depth-first, pre-order walk; children are pushed reversed so they are popped in declaration order.
    pending = [self]
    while pending:
        current = pending.pop()
        if getattr(current, "_from_deprecated_attn_block", False):
            legacy_blocks.append(current)
        pending.extend(reversed(list(current.children())))

    aliases = (("query", "to_q"), ("key", "to_k"), ("value", "to_v"))
    for attn_block in legacy_blocks:
        for deprecated_name, current_name in aliases:
            setattr(attn_block, deprecated_name, getattr(attn_block, current_name))
        attn_block.proj_attn = attn_block.to_out[0]

        # We don't _have_ to delete the old attributes, but it's helpful to ensure
        # that _all_ the weights are loaded into the new attributes and we're not
        # making an incorrect assumption that this model should be converted when
        # it really shouldn't be.
        for current_name in ("to_q", "to_k", "to_v", "to_out"):
            delattr(attn_block, current_name)
kwargs_copy = kwargs.copy() cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) # Load config if we don't provide a configuration config_path = pretrained_model_name_or_path user_agent = { "diffusers": __version__, "file_type": "model", "framework": "pytorch", } # load config config, _, _ = cls.load_config( config_path, cache_dir=cache_dir, return_unused_kwargs=True, return_commit_hash=True, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision, subfolder=subfolder, user_agent=user_agent, **kwargs, ) # resolve remapping remapped_class = _fetch_remapped_cls_from_config(config, cls) return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy) ================================================ FILE: src/diffusers/models/normalization.py ================================================ # coding=utf-8 # Copyright 2024 HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class AdaLayerNorm(nn.Module):
    r"""
    Layer norm whose output is modulated by a scale and shift derived from a timestep embedding.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`, *optional*):
            The size of the embeddings dictionary. When given, `forward` looks `timestep` up in an
            `nn.Embedding`; otherwise the caller passes `temb` directly.
        output_dim (`int`, *optional*):
            Output size of the projection; a falsy value selects `embedding_dim * 2`. The layer norm operates
            on `output_dim // 2` features.
        norm_elementwise_affine (`bool`, defaults to `False`): Forwarded to `nn.LayerNorm`.
        norm_eps (`float`, defaults to `1e-5`): Forwarded to `nn.LayerNorm`.
        chunk_dim (`int`, defaults to `0`): Dimension along which the projection is split into two halves.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_embeddings: Optional[int] = None,
        output_dim: Optional[int] = None,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        chunk_dim: int = 0,
    ):
        super().__init__()

        if not output_dim:
            output_dim = embedding_dim * 2

        self.chunk_dim = chunk_dim
        self.emb = nn.Embedding(num_embeddings, embedding_dim) if num_embeddings is not None else None
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

    def forward(
        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # A learned embedding table, when present, takes precedence over any `temb` passed by the caller.
        if self.emb is not None:
            temb = self.emb(timestep)

        modulation = self.linear(self.silu(temb))

        if self.chunk_dim != 1:
            scale, shift = modulation.chunk(2, dim=0)
        else:
            # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
            # other if-branch. This branch is specific to CogVideoX for now.
            shift, scale = (half[:, None, :] for half in modulation.chunk(2, dim=1))

        return self.norm(x) * (1 + scale) + shift
class AdaLayerNormZeroSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero) producing a single (shift, scale, gate) triple.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        norm_type (`str`, defaults to `"layer_norm"`):
            The normalization to apply. Only `"layer_norm"` is supported by this class.
        bias (`bool`, defaults to `True`): Whether the modulation projection has a bias term.

    Raises:
        ValueError: If `norm_type` is anything other than `"layer_norm"`.
    """

    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        else:
            # Only list what this class really accepts; 'fp32_layer_norm' is rejected here as well.
            raise ValueError(
                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'."
            )

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Normalize `x` and modulate it with the shift/scale projected from `emb`.

        Args:
            x (`torch.Tensor`): Hidden states to normalize.
            emb (`torch.Tensor`):
                Conditioning embedding. It is used unconditionally, so it must be provided even though the
                signature defaults it to `None`.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: The modulated hidden states and the gate chunk of the projection.
        """
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
        # `[:, None]` inserts an axis after the first dimension so the modulation broadcasts over it.
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa
class AdaGroupNorm(nn.Module):
    r"""
    GroupNorm layer whose output is scaled and shifted by a projection of a conditioning embedding.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        out_dim (`int`): Size of each of the scale and shift halves; the projection emits `out_dim * 2` features.
        num_groups (`int`): The number of groups to separate the channels into.
        act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
        eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
    """

    def __init__(
        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps

        self.act = None if act_fn is None else get_activation(act_fn)
        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        conditioning = self.act(emb) if self.act else emb

        # Two trailing singleton axes are appended so the modulation broadcasts against `x`.
        modulation = self.linear(conditioning)[:, :, None, None]
        scale, shift = modulation.chunk(2, dim=1)

        normalized = F.group_norm(x, self.num_groups, eps=self.eps)
        return normalized * (1 + scale) + shift
class LuminaLayerNormContinuous(nn.Module):
    r"""
    Layer norm scaled by a projection of a conditioning embedding, with an optional output projection.

    Parameters:
        embedding_dim (`int`): Feature size that is normalized and produced by the scale projection.
        conditioning_embedding_dim (`int`): Feature size of the conditioning embedding.
        elementwise_affine (`bool`, defaults to `True`): Forwarded to the norm layer.
        eps (`float`, defaults to `1e-5`): Forwarded to the norm layer.
        bias (`bool`, defaults to `True`): Whether the linear layers (and the norm layer) use a bias.
        norm_type (`str`, defaults to `"layer_norm"`): Only `"layer_norm"` is supported.
        out_dim (`int`, *optional*):
            When given, a final linear layer maps the result to `out_dim` features; otherwise no output
            projection is applied.

    Raises:
        ValueError: If `norm_type` is not `"layer_norm"`.
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
        out_dim: Optional[int] = None,
    ):
        super().__init__()
        # AdaLN
        self.silu = nn.SiLU()
        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

        # linear_2
        # `forward` always reads `self.linear_2`, so the attribute has to exist even when no output
        # projection is requested; leaving it unset raised `AttributeError` for `out_dim=None`.
        if out_dim is not None:
            self.linear_2 = nn.Linear(
                embedding_dim,
                out_dim,
                bias=bias,
            )
        else:
            self.linear_2 = None

    def forward(
        self,
        x: torch.Tensor,
        conditioning_embedding: torch.Tensor,
    ) -> torch.Tensor:
        # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        scale = emb
        # The scale gets an extra axis at position 1 so it broadcasts over that dimension of `x`.
        x = self.norm(x) * (1 + scale)[:, None, :]

        if self.linear_2 is not None:
            x = self.linear_2(x)

        return x
class RMSNorm(nn.Module):
    r"""
    Root-mean-square normalization over the last dimension, with an optional learnable scale.

    Parameters:
        dim (`int` or sequence of `int`): Shape of the learnable weight; an integer is wrapped into a 1-tuple.
        eps (`float`): Value added to the mean square before taking the reciprocal square root.
        elementwise_affine (`bool`, defaults to `True`): Whether to create a learnable weight initialized to ones.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()

        self.eps = eps

        shape = (dim,) if isinstance(dim, numbers.Integral) else dim
        self.dim = torch.Size(shape)

        self.weight = nn.Parameter(torch.ones(shape)) if elementwise_affine else None

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype

        # The statistic is accumulated in float32 regardless of the input dtype.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normalized = hidden_states * torch.rsqrt(mean_square + self.eps)

        if self.weight is None:
            return normalized.to(original_dtype)

        # convert into half-precision if necessary
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normalized = normalized.to(self.weight.dtype)
        return normalized * self.weight
class ResnetBlockCondNorm2D(nn.Module):
    r"""
    A Resnet block that uses normalization layers which incorporate conditioning information.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, default to `False`):
            Stored as `use_conv_shortcut`; it is not otherwise read inside this class.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"ada_group"` ):
            The normalization layer for time embedding `temb`. Currently only support "ada_group" or "spatial".
        output_scale_factor (`float`, *optional*, default to be `1.0`):
            The value the sum of the residual and the block output is divided by.
        use_in_shortcut (`bool`, *optional*, default to `None`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection. If `None`, the layer is added whenever
            `in_channels` differs from the number of output channels of the block.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer (only when `up` is `False`).
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        time_embedding_norm: str = "ada_group",  # ada_group, spatial
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm

        if groups_out is None:
            groups_out = groups

        # First conditioned norm, applied to the block input (`in_channels` channels).
        if self.time_embedding_norm == "ada_group":  # ada_group
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Second conditioned norm, applied to the output of `conv1` (`out_channels` channels).
        if self.time_embedding_norm == "ada_group":  # ada_group
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":  # spatial
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")

        self.dropout = torch.nn.Dropout(dropout)

        # A falsy `conv_2d_out_channels` (None or 0) falls back to `out_channels`.
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        # At most one resampling layer is created; `up` takes precedence over `down`.
        self.upsample = self.downsample = None
        if self.up:
            self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        # Unless the caller decides explicitly, a shortcut conv is used whenever the channel counts differ.
        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = nn.Conv2d(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=conv_shortcut_bias,
            )

    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Run the residual block on `input_tensor`, conditioning both norm layers on `temb`.

        Extra positional arguments, or a non-`None` `scale` keyword, only trigger a deprecation notice; they
        are not used in the computation below.

        Returns:
            `torch.Tensor`: `(shortcut + block_output) / output_scale_factor`, where the shortcut is the
            (optionally resampled and 1x1-convolved) input.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states, temb)

        hidden_states = self.nonlinearity(hidden_states)

        # The residual input is resampled with the same layer as the hidden states so the two can be summed.
        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)

        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states, temb)

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
out_channels (`int`, *optional*, defaults to `None`): The number of output channels for the first conv2d layer. If `None`, same as `in_channels`. dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. temb_channels (`int`, *optional*, defaults to `512`): The number of channels in the timestep embedding. groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer. groups_out (`int`, *optional*, defaults to `None`): The number of groups to use for the second normalization layer. If `None`, same as `groups`. eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. non_linearity (`str`, *optional*, defaults to `"swish"`): The activation function to use. time_embedding_norm (`str`, *optional*, defaults to `"default"`): Time scale shift config. By default, apply timestep embedding conditioning with a simple shift mechanism. Choose `"scale_shift"` for a stronger conditioning with scale and shift. kernel (`str`, *optional*, defaults to `None`): Selects the resampling used when `up` or `down` is set. `"fir"` applies a `(1, 3, 3, 1)` FIR filter, see [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]; `"sde_vp"` uses nearest-neighbor interpolation when upsampling and 2x2 average pooling when downsampling; any other value falls back to `Upsample2D`/`Downsample2D` without a convolution. output_scale_factor (`float`, *optional*, defaults to `1.0`): The scale factor to use for the output. use_in_shortcut (`bool`, *optional*, defaults to `None`): If `True`, add a 1x1 nn.conv2d layer for skip-connection. If `None`, the layer is added only when `in_channels` differs from the number of channels produced by the second convolution. up (`bool`, *optional*, defaults to `False`): If `True`, add an upsample layer. down (`bool`, *optional*, defaults to `False`): If `True`, add a downsample layer. conv_shortcut_bias (`bool`, *optional*, defaults to `True`): If `True`, adds a learnable bias to the `conv_shortcut` output. conv_2d_out_channels (`int`, *optional*, defaults to `None`): The number of channels in the output. If `None`, same as `out_channels`.
""" def __init__( self, *, in_channels: int, out_channels: Optional[int] = None, conv_shortcut: bool = False, dropout: float = 0.0, temb_channels: int = 512, groups: int = 32, groups_out: Optional[int] = None, pre_norm: bool = True, eps: float = 1e-6, non_linearity: str = "swish", skip_time_act: bool = False, time_embedding_norm: str = "default", # default, scale_shift, kernel: Optional[torch.Tensor] = None, output_scale_factor: float = 1.0, use_in_shortcut: Optional[bool] = None, up: bool = False, down: bool = False, conv_shortcut_bias: bool = True, conv_2d_out_channels: Optional[int] = None, ): super().__init__() if time_embedding_norm == "ada_group": raise ValueError( "This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead", ) if time_embedding_norm == "spatial": raise ValueError( "This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead", ) self.pre_norm = True self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.up = up self.down = down self.output_scale_factor = output_scale_factor self.time_embedding_norm = time_embedding_norm self.skip_time_act = skip_time_act if groups_out is None: groups_out = groups self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: if self.time_embedding_norm == "default": self.time_emb_proj = nn.Linear(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels) else: raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") else: self.time_emb_proj = None self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, 
affine=True) self.dropout = torch.nn.Dropout(dropout) conv_2d_out_channels = conv_2d_out_channels or out_channels self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) self.nonlinearity = get_activation(non_linearity) self.upsample = self.downsample = None if self.up: if kernel == "fir": fir_kernel = (1, 3, 3, 1) self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) elif kernel == "sde_vp": self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") else: self.upsample = Upsample2D(in_channels, use_conv=False) elif self.down: if kernel == "fir": fir_kernel = (1, 3, 3, 1) self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) elif kernel == "sde_vp": self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) else: self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut self.conv_shortcut = None if self.use_in_shortcut: self.conv_shortcut = nn.Conv2d( in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias, ) def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) hidden_states = input_tensor hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: # upsample_nearest_nhwc fails with large batch sizes. 
# unet_rl.py
def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
    """
    Add or remove a singleton axis depending on the rank of `tensor`.

    * 2D `(a, b)` -> `(a, b, 1)`
    * 3D `(a, b, c)` -> `(a, b, 1, c)`
    * 4D `(a, b, x, c)` -> `(a, b, c)` (index 0 is taken along the third axis)

    Parameters:
        tensor (`torch.Tensor`): The 2-, 3- or 4-dimensional tensor to rearrange.

    Returns:
        `torch.Tensor`: The rearranged tensor.

    Raises:
        `ValueError`: If `tensor` is not 2-, 3- or 4-dimensional.
    """
    ndim = len(tensor.shape)
    if ndim == 2:
        return tensor[:, :, None]
    if ndim == 3:
        return tensor[:, :, None, :]
    if ndim == 4:
        return tensor[:, :, 0, :]
    # Report the rank. `len(tensor)` would be the size of the first axis, and it raises `TypeError` on 0-d tensors.
    raise ValueError(f"`len(tensor.shape)`: {ndim} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    Parameters:
        inp_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        kernel_size (`int` or `tuple`): Size of the convolving kernel.
        n_groups (`int`, default `8`): Number of groups to separate the channels into.
        activation (`str`, defaults to `mish`): Name of the activation function.
    """

    def __init__(
        self,
        inp_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        n_groups: int = 8,
        activation: str = "mish",
    ):
        super().__init__()

        # Half-kernel padding. A tuple kernel size cannot be floor-divided directly, so pad per entry instead.
        try:
            padding = kernel_size // 2
        except TypeError:
            padding = tuple(k // 2 for k in kernel_size)

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=padding)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = get_activation(activation)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        features = self.conv1d(inputs)
        # The group norm is applied on a 4D view: a singleton axis is inserted before and dropped right after.
        features = rearrange_dims(features)
        features = self.group_norm(features)
        features = rearrange_dims(features)
        return self.mish(features)


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    """
    Residual 1D block with temporal convolutions.

    Parameters:
        inp_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        embed_dim (`int`): Embedding dimension.
        kernel_size (`int` or `tuple`): Size of the convolving kernel.
        activation (`str`, defaults `mish`): Name of the activation applied to the time embedding before projection.
    """

    def __init__(
        self,
        inp_channels: int,
        out_channels: int,
        embed_dim: int,
        kernel_size: Union[int, Tuple[int, int]] = 5,
        activation: str = "mish",
    ):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = get_activation(activation)
        self.time_emb = nn.Linear(embed_dim, out_channels)

        # A 1x1 convolution matches the channel count of the skip path only when it differs.
        if inp_channels != out_channels:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        time_bias = self.time_emb(self.time_emb_act(t))
        # `rearrange_dims` appends a trailing axis so the projected embedding broadcasts over the horizon.
        out = self.conv_in(inputs) + rearrange_dims(time_bias)
        out = self.conv_out(out)
        return out + self.residual_conv(inputs)


class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016

    Parameters:
        in_dim (`int`): Number of input channels.
        out_dim (`int`): Number of channels used inside the block. If `None`, same as `in_dim`. The block output
            always has `in_dim` channels because it is added to the input.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups of every group norm.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: Optional[int] = None,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
    ):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        temporal_kernel = (3, 1, 1)
        temporal_padding = (1, 0, 0)

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, in_dim),
            nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, temporal_kernel, padding=temporal_padding),
        )
        # The intermediate layers stay at `out_dim` channels so that the following group norm (built for `out_dim`)
        # sees the channel count it expects. Shapes are unchanged when `out_dim == in_dim`.
        self.conv2 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, temporal_kernel, padding=temporal_padding),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, temporal_kernel, padding=temporal_padding),
        )
        # The last layer projects back to `in_dim` so the result can be added to the input.
        self.conv4 = nn.Sequential(
            nn.GroupNorm(norm_num_groups, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, temporal_kernel, padding=temporal_padding),
        )

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
        """
        Args:
            hidden_states : [ (batch_size * num_frames) x channels x height x width ]
            num_frames : number of consecutive entries of the first axis that form one sequence

        returns:
            tensor of the same shape as `hidden_states`
        """
        # (batch * frames, C, H, W) -> (batch, C, frames, H, W)
        hidden_states = hidden_states.reshape(-1, num_frames, *hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)
        hidden_states = identity + hidden_states

        # (batch, C, frames, H, W) -> (batch * frames, C, H, W)
        batch_size, _, frames, height, width = hidden_states.shape
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, -1, height, width)
        return hidden_states
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. """ def __init__( self, in_channels: int, out_channels: Optional[int] = None, temb_channels: int = 512, eps: float = 1e-6, ): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels kernel_size = (3, 1, 1) padding = [k // 2 for k in kernel_size] self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True) self.conv1 = nn.Conv3d( in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, ) if temb_channels is not None: self.time_emb_proj = nn.Linear(temb_channels, out_channels) else: self.time_emb_proj = None self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True) self.dropout = torch.nn.Dropout(0.0) self.conv2 = nn.Conv3d( out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, ) self.nonlinearity = get_activation("silu") self.use_in_shortcut = self.in_channels != out_channels self.conv_shortcut = None if self.use_in_shortcut: self.conv_shortcut = nn.Conv3d( in_channels, out_channels, kernel_size=1, stride=1, padding=0, ) def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor) -> torch.Tensor: hidden_states = input_tensor hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states) if self.time_emb_proj is not None: temb = self.nonlinearity(temb) temb = self.time_emb_proj(temb)[:, :, :, None, None] temb = temb.permute(0, 2, 1, 3, 4) hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: input_tensor = self.conv_shortcut(input_tensor) output_tensor = input_tensor + hidden_states return 
# VideoResBlock
class SpatioTemporalResBlock(nn.Module):
    r"""
    A SpatioTemporal Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resnet.
        temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
        merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
            The merge strategy to use for the temporal mixing.
        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
            If `True`, switch the spatial and temporal mixing.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        temb_channels: int = 512,
        eps: float = 1e-6,
        temporal_eps: Optional[float] = None,
        merge_factor: float = 0.5,
        merge_strategy="learned_with_images",
        switch_spatial_to_temporal_mix: bool = False,
    ):
        super().__init__()

        self.spatial_res_block = ResnetBlock2D(
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=eps,
        )

        # The temporal block runs on the output of the spatial block, so it keeps that channel count.
        temporal_channels = out_channels if out_channels is not None else in_channels
        self.temporal_res_block = TemporalResnetBlock(
            in_channels=temporal_channels,
            out_channels=temporal_channels,
            temb_channels=temb_channels,
            eps=temporal_eps if temporal_eps is not None else eps,
        )

        self.time_mixer = AlphaBlender(
            alpha=merge_factor,
            merge_strategy=merge_strategy,
            switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            hidden_states : [ (batch_size * num_frames) x channels x height x width ]
            temb : optional embedding, reshaped to [ batch_size x num_frames x -1 ] for the temporal block
            image_only_indicator : tensor whose last dimension is `num_frames`; also forwarded to the mixer

        Raises:
            `ValueError`: If `image_only_indicator` is `None`.
        """
        # The number of frames is read from the indicator, so it cannot be omitted despite the `None` default.
        if image_only_indicator is None:
            raise ValueError(
                "`image_only_indicator` has to be provided, its last dimension defines the number of frames."
            )
        num_frames = image_only_indicator.shape[-1]
        hidden_states = self.spatial_res_block(hidden_states, temb)

        batch_frames, channels, height, width = hidden_states.shape
        batch_size = batch_frames // num_frames

        # (batch * frames, C, H, W) -> (batch, C, frames, H, W); the spatial result is kept for the final blend.
        hidden_states = hidden_states.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
        hidden_states_mix = hidden_states

        if temb is not None:
            temb = temb.reshape(batch_size, num_frames, -1)

        hidden_states = self.temporal_res_block(hidden_states, temb)
        hidden_states = self.time_mixer(
            x_spatial=hidden_states_mix,
            x_temporal=hidden_states,
            image_only_indicator=image_only_indicator,
        )

        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
        return hidden_states


class AlphaBlender(nn.Module):
    r"""
    A module to blend spatial and temporal features.

    Parameters:
        alpha (`float`): The initial value of the blending factor.
        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
            The merge strategy to use for the temporal mixing.
        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
            If `True`, switch the spatial and temporal mixing.
    """

    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        switch_spatial_to_temporal_mix: bool = False,
    ):
        super().__init__()
        self.merge_strategy = merge_strategy
        self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix  # For TemporalVAE

        if merge_strategy not in self.strategies:
            raise ValueError(f"merge_strategy needs to be in {self.strategies}")

        if self.merge_strategy == "fixed":
            # A buffer is stored with the module but is not trained.
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        else:
            # "learned" and "learned_with_images" (the only other validated strategies) train the factor.
            self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))

    def get_alpha(self, image_only_indicator: Optional[torch.Tensor], ndims: int) -> torch.Tensor:
        """
        Compute the weight of the spatial branch.

        Args:
            image_only_indicator : required for `learned_with_images`; entries that evaluate to `True` get weight 1
            ndims : number of dimensions of the tensors to blend, only used by `learned_with_images`

        Raises:
            `ValueError`: If the indicator is missing for `learned_with_images` or `ndims` is neither 3 nor 5.
            `NotImplementedError`: If `merge_strategy` is not one of `strategies`.
        """
        if self.merge_strategy == "fixed":
            return self.mix_factor

        if self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)

        if self.merge_strategy != "learned_with_images":
            raise NotImplementedError

        if image_only_indicator is None:
            raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")

        alpha = torch.where(
            image_only_indicator.bool(),
            torch.ones(1, 1, device=image_only_indicator.device),
            torch.sigmoid(self.mix_factor)[..., None],
        )

        # (batch, channel, frames, height, width)
        if ndims == 5:
            return alpha[:, None, :, None, None]
        # (batch*frames, height*width, channels)
        if ndims == 3:
            return alpha.reshape(-1)[:, None, None]
        raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")

    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return `alpha * x_spatial + (1 - alpha) * x_temporal`, with `alpha` replaced by `1 - alpha` when switched."""
        alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
        alpha = alpha.to(x_spatial.dtype)

        if self.switch_spatial_to_temporal_mix:
            alpha = 1.0 - alpha

        return alpha * x_spatial + (1.0 - alpha) * x_temporal
import flax.linen as nn
import jax
import jax.numpy as jnp


class FlaxUpsample2D(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution.

    Attributes are declared as dataclass-style fields:
        out_channels: number of features produced by the convolution.
        dtype: dtype passed to the convolution.
    """

    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # The input is unpacked as (batch, height, width, channels); only the two middle axes are doubled.
        batch, height, width, channels = hidden_states.shape
        upscaled_shape = (batch, height * 2, width * 2, channels)
        upscaled = jax.image.resize(hidden_states, shape=upscaled_shape, method="nearest")
        return self.conv(upscaled)


class FlaxDownsample2D(nn.Module):
    """Downsampling done by a single 3x3 convolution with stride 2.

    Attributes:
        out_channels: number of features produced by the convolution.
        dtype: dtype passed to the convolution.
    """

    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Symmetric one-pixel padding is used here; an earlier variant (padding="VALID" combined with an explicit
        # one-sided `jnp.pad` of height and width) was left disabled in the original code.
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        return self.conv(hidden_states)


class FlaxResnetBlock2D(nn.Module):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages with a projected time embedding added in between.

    Attributes:
        in_channels: number of input channels.
        out_channels: number of output channels; `None` means same as `in_channels`.
        dropout_prob: dropout rate applied before the second convolution.
        use_nin_shortcut: whether to project the residual with a 1x1 convolution; `None` enables it only when
            input and output channel counts differ.
        dtype: dtype passed to the convolutions and the dense projection.
    """

    in_channels: int
    out_channels: int = None
    dropout_prob: float = 0.0
    use_nin_shortcut: bool = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.out_channels is None:
            features = self.in_channels
        else:
            features = self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv1 = nn.Conv(
            features,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        self.time_emb_proj = nn.Dense(features, dtype=self.dtype)

        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.dropout = nn.Dropout(self.dropout_prob)
        self.conv2 = nn.Conv(
            features,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        if self.use_nin_shortcut is None:
            needs_projection = self.in_channels != features
        else:
            needs_projection = self.use_nin_shortcut

        shortcut = None
        if needs_projection:
            shortcut = nn.Conv(
                features,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )
        self.conv_shortcut = shortcut

    def __call__(self, hidden_states, temb, deterministic=True):
        skip = hidden_states

        hidden_states = self.conv1(nn.swish(self.norm1(hidden_states)))

        # Two singleton axes are inserted after the first axis so the projection broadcasts over them.
        time_bias = self.time_emb_proj(nn.swish(temb))
        hidden_states = hidden_states + time_bias[:, None, None]

        hidden_states = nn.swish(self.norm2(hidden_states))
        hidden_states = self.dropout(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            skip = self.conv_shortcut(skip)

        return hidden_states + skip
# Taken from the original aura flow inference code.
def find_multiple(n: int, k: int) -> int:
    """
    Round `n` up to the nearest multiple of `k`.

    Parameters:
        n (`int`): The value to round up.
        k (`int`): The positive step the result has to be a multiple of.

    Returns:
        `int`: `n` itself when it is already a multiple of `k`, otherwise the next larger multiple of `k`.

    Raises:
        `ValueError`: If `k` is not positive.
    """
    # A zero step would fail with a bare ZeroDivisionError and a negative one would silently round the wrong way.
    if k <= 0:
        raise ValueError(f"`k` has to be a positive integer, but got {k}.")

    remainder = n % k
    if remainder == 0:
        return n
    return n + k - remainder


# Aura Flow patch embed doesn't use convs for projections.
# Additionally, it uses learned positional embeddings.
class AuraFlowPatchEmbed(nn.Module):
    """
    Split a latent into non-overlapping square patches, project every flattened patch with a linear layer and add
    a learned positional embedding.

    Parameters:
        height (`int`, defaults to `224`): Height used to derive the default patch grid.
        width (`int`, defaults to `224`): Width used to derive the default patch grid.
        patch_size (`int`, defaults to `16`): Side length of a patch.
        in_channels (`int`, defaults to `3`): Number of channels of the latent.
        embed_dim (`int`, defaults to `768`): Size of the projected patch embedding.
        pos_embed_max_size (`int`, *optional*, defaults to `None`):
            Number of learned positional embeddings. If `None`, one embedding per patch of the
            `height` x `width` grid is created.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        pos_embed_max_size=None,
    ):
        super().__init__()

        self.num_patches = (height // patch_size) * (width // patch_size)
        # `torch.randn` cannot take `None` as a size, so fall back to the size of the configured patch grid.
        if pos_embed_max_size is None:
            pos_embed_max_size = self.num_patches
        self.pos_embed_max_size = pos_embed_max_size

        self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1)

        self.patch_size = patch_size
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size

    def forward(self, latent):
        """
        Args:
            latent : [ batch_size x channels x height x width ], height and width divisible by `patch_size`

        returns:
            [ batch_size x num_patches x embed_dim ]; the positional embedding is added as a whole, so the
            number of patches has to be broadcastable against `pos_embed_max_size`.
        """
        batch_size, num_channels, height, width = latent.size()
        # `reshape` (instead of `view`) also accepts non-contiguous latents.
        latent = latent.reshape(
            batch_size,
            num_channels,
            height // self.patch_size,
            self.patch_size,
            width // self.patch_size,
            self.patch_size,
        )
        # (B, C, H/p, p, W/p, p) -> (B, H/p, W/p, C, p, p) -> (B, H/p * W/p, C * p * p)
        latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        latent = self.proj(latent)
        return latent + self.pos_embed


# Taken from the original Aura flow inference code.
# Our feedforward only has GELU but Aura uses SiLU.
class AuraFlowFeedForward(nn.Module):
    """Bias-free gated feed-forward: `out_projection(silu(linear_1(x)) * linear_2(x))`.

    Parameters:
        dim (`int`): Size of the input and output features.
        hidden_dim (`int`, *optional*):
            Requested hidden width, `4 * dim` when omitted. The width actually used is two thirds of it,
            rounded up to a multiple of 256.
    """

    def __init__(self, dim, hidden_dim=None) -> None:
        super().__init__()
        requested_dim = 4 * dim if hidden_dim is None else hidden_dim
        inner_dim = find_multiple(int(2 * requested_dim / 3), 256)

        self.linear_1 = nn.Linear(dim, inner_dim, bias=False)
        self.linear_2 = nn.Linear(dim, inner_dim, bias=False)
        self.out_projection = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.linear_1(x))
        value = self.linear_2(x)
        return self.out_projection(gate * value)


class AuraFlowPreFinalBlock(nn.Module):
    """Modulate the hidden states with a scale and shift projected from a conditioning embedding."""

    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False)

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        # Activate the conditioning, cast it to the dtype of `x`, then project to scale and shift halves.
        activated = self.silu(conditioning_embedding).to(x.dtype)
        modulation = self.linear(activated)
        scale, shift = modulation.chunk(2, dim=1)
        # Insert an axis at position 1 so the modulation broadcasts over that axis of `x`.
        return x * (1 + scale).unsqueeze(1) + shift.unsqueeze(1)


@maybe_allow_in_graph
class AuraFlowSingleTransformerBlock(nn.Module):
    """Single-stream counterpart of `AuraFlowJointTransformerBlock`: one DiT block instead of an MMDiT block."""

    def __init__(self, dim, num_attention_heads, attention_head_dim):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")

        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="fp32_layer_norm",
            out_dim=dim,
            bias=False,
            out_bias=False,
            processor=AuraFlowAttnProcessor2_0(),
        )

        self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
        self.ff = AuraFlowFeedForward(dim, dim * 4)

    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
        # `norm1` yields the normalised input together with the gates / shift / scale derived from `temb`.
        normed_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        # Self-attention over the normalised sequence.
        attn_out = self.attn(hidden_states=normed_states)

        # Gated attention residual, second norm and modulation feed the MLP.
        mlp_input = self.norm2(hidden_states + gate_msa.unsqueeze(1) * attn_out)
        mlp_input = mlp_input * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        mlp_out = gate_mlp.unsqueeze(1) * self.ff(mlp_input)

        # The gated MLP output is added onto the block's original input.
        return hidden_states + mlp_out


@maybe_allow_in_graph
class AuraFlowJointTransformerBlock(nn.Module):
    r"""
    Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive):

        * QK Norm in the attention blocks
        * No bias in the attention blocks
        * Most LayerNorms are in FP32

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        is_last (`bool`): Boolean to determine if this is the last block in the model.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
        self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")

        processor = AuraFlowAttnProcessor2_0()
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            added_proj_bias=False,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="fp32_layer_norm",
            out_dim=dim,
            bias=False,
            out_bias=False,
            processor=processor,
            context_pre_only=False,
        )

        self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
        self.ff = AuraFlowFeedForward(dim, dim * 4)
        self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
        self.ff_context = AuraFlowFeedForward(dim, dim * 4)

    def forward(
        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
    ):
        residual = hidden_states
        residual_context = encoder_hidden_states

        # Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) # Attention. attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states ) # Process attention outputs for the `hidden_states`. hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] hidden_states = gate_mlp.unsqueeze(1) * self.ff(hidden_states) hidden_states = residual + hidden_states # Process attention outputs for the `encoder_hidden_states`. encoder_hidden_states = self.norm2_context(residual_context + c_gate_msa.unsqueeze(1) * context_attn_output) encoder_hidden_states = encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] encoder_hidden_states = c_gate_mlp.unsqueeze(1) * self.ff_context(encoder_hidden_states) encoder_hidden_states = residual_context + encoder_hidden_states return encoder_hidden_states, hidden_states class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): r""" A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/). Parameters: sample_size (`int`): The width of the latent images. This is fixed during training since it is used to learn a number of position embeddings. patch_size (`int`): Patch size to turn the input data into small patches. in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use. num_single_dit_layers (`int`, *optional*, defaults to 4): The number of layers of Transformer blocks to use. These blocks use concatenated image and text representations. attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`. out_channels (`int`, defaults to 16): Number of output channels. pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: int = 64, patch_size: int = 2, in_channels: int = 4, num_mmdit_layers: int = 4, num_single_dit_layers: int = 32, attention_head_dim: int = 256, num_attention_heads: int = 12, joint_attention_dim: int = 2048, caption_projection_dim: int = 3072, out_channels: int = 4, pos_embed_max_size: int = 1024, ): super().__init__() default_out_channels = in_channels self.out_channels = out_channels if out_channels is not None else default_out_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.pos_embed = AuraFlowPatchEmbed( height=self.config.sample_size, width=self.config.sample_size, patch_size=self.config.patch_size, in_channels=self.config.in_channels, embed_dim=self.inner_dim, pos_embed_max_size=pos_embed_max_size, ) self.context_embedder = nn.Linear( self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False ) self.time_step_embed = Timesteps(num_channels=256, downscale_freq_shift=0, scale=1000, flip_sin_to_cos=True) self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim) self.joint_transformer_blocks = nn.ModuleList( [ AuraFlowJointTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, ) for i in range(self.config.num_mmdit_layers) ] ) self.single_transformer_blocks = nn.ModuleList( [ AuraFlowSingleTransformerBlock( 
dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, ) for _ in range(self.config.num_single_dit_layers) ] ) self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) # https://arxiv.org/abs/2309.16588 # prevents artifacts in the attention maps self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02) self.gradient_checkpointing = False @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. 
This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0 def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. For cross-attention modules, key and value projection matrices are fused. This API is 🧪 experimental. """ self.original_attn_processors = None for _, attn_processor in self.attn_processors.items(): if "Added" in str(attn_processor.__class__.__name__): raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") self.original_attn_processors = self.attn_processors for module in self.modules(): if isinstance(module, Attention): module.fuse_projections(fuse=True) self.set_attn_processor(FusedAuraFlowAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): """Disables the fused QKV projection if enabled. This API is 🧪 experimental. 
""" if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, timestep: torch.LongTensor = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: height, width = hidden_states.shape[-2:] # Apply patch embedding, timestep embedding, and project the caption embeddings. hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too. temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype) temb = self.time_step_proj(temb) encoder_hidden_states = self.context_embedder(encoder_hidden_states) encoder_hidden_states = torch.cat( [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 ) # MMDiT blocks. 
for index_block, block in enumerate(self.joint_transformer_blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states, temb, **ckpt_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ) # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text) if len(self.single_transformer_blocks) > 0: encoder_seq_len = encoder_hidden_states.size(1) combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} combined_hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), combined_hidden_states, temb, **ckpt_kwargs, ) else: combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb) hidden_states = combined_hidden_states[:, encoder_seq_len:] hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) # unpatchify patch_size = self.config.patch_size out_channels = self.config.out_channels height = height // patch_size width = width // patch_size hidden_states = hidden_states.reshape( 
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, out_channels) ) hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) output = hidden_states.reshape( shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size) ) if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) ================================================ FILE: src/diffusers/models/transformers/cogvideox_transformer_3d.py ================================================ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, Optional, Union import torch from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import Attention, FeedForward from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero logger = logging.get_logger(__name__) # pylint: disable=invalid-name @maybe_allow_in_graph class CogVideoXBlock(nn.Module): r""" Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. 
Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. attention_bias (: obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. qk_norm (`bool`, defaults to `True`): Whether or not to use normalization after query and key projections in Attention. norm_elementwise_affine (`bool`, defaults to `True`): Whether to use learnable elementwise affine parameters for normalization. norm_eps (`float`, defaults to `1e-5`): Epsilon value for normalization layers. final_dropout (`bool` defaults to `False`): Whether to apply a final dropout after the last feed-forward layer. ff_inner_dim (`int`, *optional*, defaults to `None`): Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. ff_bias (`bool`, defaults to `True`): Whether or not to use bias in Feed-forward layer. attention_out_bias (`bool`, defaults to `True`): Whether or not to use bias in Attention output projection layer. """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, time_embed_dim: int, dropout: float = 0.0, activation_fn: str = "gelu-approximate", attention_bias: bool = False, qk_norm: bool = True, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, final_dropout: bool = True, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() # 1. 
Self Attention self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.attn1 = Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, ) # 2. Feed Forward self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, ) -> torch.Tensor: norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( hidden_states, encoder_hidden_states, temb ) # attention text_length = norm_encoder_hidden_states.size(1) # CogVideoX uses concatenated text + video embeddings with self-attention instead of using # them in cross-attention individually norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=None, ) hidden_states = hidden_states + gate_msa * attn_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length] # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( hidden_states, encoder_hidden_states, temb ) # feed-forward norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + gate_ff * ff_output[:, text_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length] return hidden_states, encoder_hidden_states class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): """ A Transformer model for video-like data in 
[CogVideoX](https://github.com/THUDM/CogVideo). Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): The number of channels in the input. out_channels (`int`, *optional*): The number of channels in the output. num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. attention_bias (`bool`, *optional*): Configure if the `TransformerBlocks` attention should contain a bias parameter. sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). This is fixed during training since it is used to learn a number of position embeddings. patch_size (`int`, *optional*): The size of the patches to use in the patch embedding layer. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. num_embeds_ada_norm ( `int`, *optional*): The number of diffusion steps used during training. Pass if at least one of the norm_layers is `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. norm_type (`str`, *optional*, defaults to `"layer_norm"`): The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`. norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Whether or not to use elementwise affine in normalization layers. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers. 
caption_channels (`int`, *optional*): The number of channels in the caption embeddings. video_length (`int`, *optional*): The number of frames in the video-like data. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, num_attention_heads: int = 30, attention_head_dim: int = 64, in_channels: Optional[int] = 16, out_channels: Optional[int] = 16, flip_sin_to_cos: bool = True, freq_shift: int = 0, time_embed_dim: int = 512, text_embed_dim: int = 4096, num_layers: int = 30, dropout: float = 0.0, attention_bias: bool = True, sample_width: int = 90, sample_height: int = 60, sample_frames: int = 49, patch_size: int = 2, temporal_compression_ratio: int = 4, max_text_seq_length: int = 226, activation_fn: str = "gelu-approximate", timestep_activation_fn: str = "silu", norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, spatial_interpolation_scale: float = 1.875, temporal_interpolation_scale: float = 1.0, ): super().__init__() inner_dim = num_attention_heads * attention_head_dim post_patch_height = sample_height // patch_size post_patch_width = sample_width // patch_size post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1 self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames # 1. Patch embedding self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True) self.embedding_dropout = nn.Dropout(dropout) # 2. 
3D positional embeddings spatial_pos_embedding = get_3d_sincos_pos_embed( inner_dim, (post_patch_width, post_patch_height), post_time_compression_frames, spatial_interpolation_scale, temporal_interpolation_scale, ) spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False) pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding) self.register_buffer("pos_embedding", pos_embedding, persistent=False) # 3. Time embeddings self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) # 4. Define spatio-temporal transformers blocks self.transformer_blocks = nn.ModuleList( [ CogVideoXBlock( dim=inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, time_embed_dim=time_embed_dim, dropout=dropout, activation_fn=activation_fn, attention_bias=attention_bias, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, ) for _ in range(num_layers) ] ) self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine) # 5. Output blocks self.norm_out = AdaLayerNorm( embedding_dim=time_embed_dim, output_dim=2 * inner_dim, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, chunk_dim=1, ) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) self.gradient_checkpointing = False def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, timestep: Union[int, float, torch.LongTensor], timestep_cond: Optional[torch.Tensor] = None, return_dict: bool = True, ): batch_size, num_frames, channels, height, width = hidden_states.shape # 1. 
Time embedding timesteps = timestep t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=hidden_states.dtype) emb = self.time_embedding(t_emb, timestep_cond) # 2. Patch embedding hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) # 3. Position embedding seq_length = height * width * num_frames // (self.config.patch_size**2) pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length] hidden_states = hidden_states + pos_embeds hidden_states = self.embedding_dropout(hidden_states) encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length] hidden_states = hidden_states[:, self.config.max_text_seq_length :] # 5. Transformer blocks for i, block in enumerate(self.transformer_blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states, emb, **ckpt_kwargs, ) else: hidden_states, encoder_hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=emb, ) hidden_states = self.norm_final(hidden_states) # 6. Final block hidden_states = self.norm_out(hidden_states, temb=emb) hidden_states = self.proj_out(hidden_states) # 7. 
Unpatchify p = self.config.patch_size output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) ================================================ FILE: src/diffusers/models/transformers/dit_transformer_2d.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, Optional import torch import torch.nn.functional as F from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ..attention import BasicTransformerBlock from ..embeddings import PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name class DiTTransformer2DModel(ModelMixin, ConfigMixin): r""" A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748). Parameters: num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (int, optional, defaults to 72): The number of channels in each head. in_channels (int, defaults to 4): The number of channels in the input. 
out_channels (int, optional): The number of channels in the output. Specify this parameter if the output channel number differs from the input. num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use. dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks. norm_num_groups (int, optional, defaults to 32): Number of groups for group normalization within Transformer blocks. attention_bias (bool, optional, defaults to True): Configure if the Transformer blocks' attention should contain a bias parameter. sample_size (int, defaults to 32): The width of the latent images. This parameter is fixed during training. patch_size (int, defaults to 2): Size of the patches the model processes, relevant for architectures working on non-sequential data. activation_fn (str, optional, defaults to "gelu-approximate"): Activation function to use in feed-forward networks within Transformer blocks. num_embeds_ada_norm (int, optional, defaults to 1000): Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during inference. upcast_attention (bool, optional, defaults to False): If true, upcasts the attention mechanism dimensions for potentially improved performance. norm_type (str, optional, defaults to "ada_norm_zero"): Specifies the type of normalization used, can be 'ada_norm_zero'. norm_elementwise_affine (bool, optional, defaults to False): If true, enables element-wise affine parameters in the normalization layers. norm_eps (float, optional, defaults to 1e-5): A small constant added to the denominator in normalization layers to prevent division by zero. 
""" _supports_gradient_checkpointing = True @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 72, in_channels: int = 4, out_channels: Optional[int] = None, num_layers: int = 28, dropout: float = 0.0, norm_num_groups: int = 32, attention_bias: bool = True, sample_size: int = 32, patch_size: int = 2, activation_fn: str = "gelu-approximate", num_embeds_ada_norm: Optional[int] = 1000, upcast_attention: bool = False, norm_type: str = "ada_norm_zero", norm_elementwise_affine: bool = False, norm_eps: float = 1e-5, ): super().__init__() # Validate inputs. if norm_type != "ada_norm_zero": raise NotImplementedError( f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." ) elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None: raise ValueError( f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." ) # Set some common variables used across the board. self.attention_head_dim = attention_head_dim self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.out_channels = in_channels if out_channels is None else out_channels self.gradient_checkpointing = False # 2. Initialize the position embedding and transformer blocks. 
self.height = self.config.sample_size self.width = self.config.sample_size self.patch_size = self.config.patch_size self.pos_embed = PatchEmbed( height=self.config.sample_size, width=self.config.sample_size, patch_size=self.config.patch_size, in_channels=self.config.in_channels, embed_dim=self.inner_dim, ) self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( self.inner_dim, self.config.num_attention_heads, self.config.attention_head_dim, dropout=self.config.dropout, activation_fn=self.config.activation_fn, num_embeds_ada_norm=self.config.num_embeds_ada_norm, attention_bias=self.config.attention_bias, upcast_attention=self.config.upcast_attention, norm_type=norm_type, norm_elementwise_affine=self.config.norm_elementwise_affine, norm_eps=self.config.norm_eps, ) for _ in range(self.config.num_layers) ] ) # 3. Output blocks. self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) self.proj_out_2 = nn.Linear( self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels ) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward( self, hidden_states: torch.Tensor, timestep: Optional[torch.LongTensor] = None, class_labels: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, return_dict: bool = True, ): """ The [`DiTTransformer2DModel`] forward method. Args: hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): Input `hidden_states`. timestep ( `torch.LongTensor`, *optional*): Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Used to indicate class labels conditioning. 
Optional class labels to be applied as an embedding in `AdaLayerZeroNorm`. cross_attention_kwargs ( `Dict[str, Any]`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ # 1. Input height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size hidden_states = self.pos_embed(hidden_states) # 2. Blocks for block in self.transformer_blocks: if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, None, None, None, timestep, cross_attention_kwargs, class_labels, **ckpt_kwargs, ) else: hidden_states = block( hidden_states, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, ) # 3. 
class DualTransformer2DModel(nn.Module):
    """
    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.

    Both wrapped transformers are constructed from the same arguments. During `forward`, each of the two condition
    segments of `encoder_hidden_states` is routed to one of the transformers and the two resulting residuals are
    blended using `mix_ratio`.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            Forwarded unchanged to both wrapped `Transformer2DModel`s.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to, but not more than, `num_embeds_ada_norm` steps.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
    ):
        super().__init__()
        self.transformers = nn.ModuleList(
            [
                Transformer2DModel(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    in_channels=in_channels,
                    num_layers=num_layers,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    attention_bias=attention_bias,
                    sample_size=sample_size,
                    num_vector_embeds=num_vector_embeds,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                )
                for _ in range(2)
            ]
        )

        # Variables that can be set by a pipeline:

        # The ratio of transformer1 to transformer2's output states to be combined during inference
        self.mix_ratio = 0.5

        # The shape of `encoder_hidden_states` is expected to be
        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
        self.condition_lengths = [77, 257]

        # Which transformer to use to encode which condition.
        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
        self.transformer_index_for_condition = [1, 0]

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        Run both wrapped transformers on their own condition segment and blend the results.

        Args:
            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states.
            encoder_hidden_states ( `torch.Tensor` of shape
                `(batch size, condition_lengths[0] + condition_lengths[1], num_features)`):
                Conditional embeddings for the cross attention layers. The two condition segments are concatenated
                along dimension 1 and are sliced apart here using `self.condition_lengths`.
            timestep ( `torch.long`, *optional*):
                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
            attention_mask (`torch.Tensor`, *optional*):
                Accepted for interface compatibility; it is currently not used by this method.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformers.transformer_2d.Transformer2DModelOutput`] instead
                of a plain tuple.

        Returns:
            [`~models.transformers.transformer_2d.Transformer2DModelOutput`] or `tuple`:
            [`~models.transformers.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        input_states = hidden_states

        encoded_states = []
        tokens_start = 0
        # attention_mask is not used yet
        # Exactly two conditions are mixed: pair each condition length with the transformer assigned to it.
        for condition_length, transformer_index in zip(
            self.condition_lengths[:2], self.transformer_index_for_condition[:2]
        ):
            tokens_end = tokens_start + condition_length
            # for each of the two transformers, pass the corresponding condition tokens
            condition_state = encoder_hidden_states[:, tokens_start:tokens_end]
            encoded_state = self.transformers[transformer_index](
                input_states,
                encoder_hidden_states=condition_state,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            # Store only the residual so the input is added back exactly once after mixing.
            encoded_states.append(encoded_state - input_states)
            tokens_start = tokens_end

        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
        output_states = output_states + input_states

        if not return_dict:
            return (output_states,)

        return Transformer2DModelOutput(sample=output_states)
class AdaLayerNormShift(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    The input is normalized with `FP32LayerNorm` and then shifted by a linear projection of the SiLU-activated
    conditioning embedding.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        elementwise_affine (`bool`, *optional*, defaults to `True`):
            Forwarded to `FP32LayerNorm` as `elementwise_affine`.
        eps (`float`, *optional*, defaults to 1e-6):
            Forwarded to `FP32LayerNorm` as `eps`.
    """

    def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim)
        self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        # The activation is evaluated in float32 and cast back to the embedding dtype before the projection.
        shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype))
        # `unsqueeze(dim=1)` inserts an axis at position 1 so the shift broadcasts over that axis of `x`.
        x = self.norm(x) + shift.unsqueeze(dim=1)
        return x


@maybe_allow_in_graph
class HunyuanDiTBlock(nn.Module):
    r"""
    Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and
    QKNorm

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        cross_attention_dim (`int`, *optional*, defaults to 1024):
            The size of the encoder_hidden_states vector for cross attention.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            A small constant added to the denominator in normalization layers to prevent division by zero.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*):
            The size of the hidden layer in the feed-forward block. Defaults to `None`.
        ff_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the feed-forward block.
        skip (`bool`, *optional*, defaults to `False`):
            Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
        qk_norm (`bool`, *optional*, defaults to `True`):
            Whether to use normalization in QK calculation. Defaults to `True`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        cross_attention_dim: int = 1024,
        dropout=0.0,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-6,
        final_dropout: bool = False,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        skip: bool = False,
        qk_norm: bool = True,
    ):
        super().__init__()

        # Define 3 blocks. Each block has its own normalization layer.
        # NOTE: when new version comes, check norm2 and norm 3
        # 1. Self-Attn
        self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=True,
            processor=HunyuanAttnProcessor2_0(),
        )

        # 2. Cross-Attn
        self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=True,
            processor=HunyuanAttnProcessor2_0(),
        )
        # 3. Feed-forward
        self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.ff = FeedForward(
            dim,
            dropout=dropout,  ### 0.0
            activation_fn=activation_fn,  ### approx GeLU
            final_dropout=final_dropout,  ### 0.0
            inner_dim=ff_inner_dim,  ### int(dim * mlp_ratio)
            bias=ff_bias,
        )

        # 4. Skip Connection
        if skip:
            self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True)
            self.skip_linear = nn.Linear(2 * dim, dim)
        else:
            self.skip_linear = None

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb=None,
        skip=None,
    ) -> torch.Tensor:
        """
        Apply the optional long skip connection, self-attention, cross-attention and feed-forward in sequence.

        Args:
            hidden_states (`torch.Tensor`): The block input.
            encoder_hidden_states (`torch.Tensor`, *optional*): Passed to the cross-attention layer.
            temb (`torch.Tensor`, *optional*): Conditioning embedding consumed by the first normalization layer.
            image_rotary_emb (*optional*): Passed unchanged to both attention layers.
            skip (`torch.Tensor`, *optional*):
                Tensor concatenated with `hidden_states` along the last dimension. It is only read when the block was
                built with `skip=True`, in which case it must be provided.

        Returns:
            `torch.Tensor`: The updated hidden states.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Long Skip Connection
        if self.skip_linear is not None:
            cat = torch.cat([hidden_states, skip], dim=-1)
            cat = self.skip_norm(cat)
            hidden_states = self.skip_linear(cat)

        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states, temb)  ### checked: self.norm1 is correct
        attn_output = self.attn1(
            norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_output

        # 2. Cross-Attention
        hidden_states = hidden_states + self.attn2(
            self.norm2(hidden_states),
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
        mlp_inputs = self.norm3(hidden_states)
        if self._chunk_size is not None:
            # Honour `set_chunk_feed_forward`: run the feed-forward on slices of `_chunk_size` along `_chunk_dim`
            # and stitch the results back together, so only one slice's intermediates are alive at a time.
            # NOTE(review): assumes the feed-forward treats positions along the chunked dim independently, as in
            # `BasicTransformerBlock` - confirm against `FeedForward`.
            ff_output = torch.cat(
                [self.ff(chunk) for chunk in mlp_inputs.split(self._chunk_size, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(mlp_inputs)
        hidden_states = hidden_states + ff_output

        return hidden_states


class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
    """
    HunYuanDiT: Diffusion model with a Transformer backbone.

    Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88):
            The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**). Although the
            default is `None`, the constructor multiplies this value, so it has to be set.
        patch_size (`int`, *optional*):
            The size of the patch to use for the input. Although the default is `None`, the constructor multiplies
            this value, so it has to be set.
        activation_fn (`str`, *optional*, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        sample_size (`int`, *optional*, defaults to 32):
            The width of the latent images. This is fixed during training since it is used to learn a number of
            position embeddings.
        hidden_size (`int`, *optional*, defaults to 1152):
            The size of hidden layer in the conditioning embedding layers; also used as the patch embedding width.
            NOTE(review): the blocks are built with `num_attention_heads * attention_head_dim` channels, so this
            presumably has to equal that product - confirm against the released configs.
        num_layers (`int`, *optional*, defaults to 28):
            The number of layers of Transformer blocks to use. With an odd value, `forward` stores one skip
            activation fewer than the skip-connected blocks consume, so use an even value.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            The ratio of the hidden layer size to the input size.
        learn_sigma (`bool`, *optional*, defaults to `True`):
            Whether to predict variance. When `True` the model outputs `2 * in_channels` channels.
        cross_attention_dim (`int`, *optional*, defaults to 1024):
            The number of dimension in the clip text embedding.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            Stored in the config; it is not read by the constructor body.
        cross_attention_dim_t5 (`int`, *optional*, defaults to 2048):
            The number dimensions in t5 text embedding.
        pooled_projection_dim (`int`, *optional*, defaults to 1024):
            The size of the pooled projection.
        text_len (`int`, *optional*, defaults to 77):
            The length of the clip text embedding.
        text_len_t5 (`int`, *optional*, defaults to 256):
            The length of the T5 text embedding.
        use_style_cond_and_image_meta_size (`bool`, *optional*, defaults to `True`):
            Whether or not to use style condition and image meta size. True for version <=1.1, False for version >=
            1.2
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "gelu-approximate",
        sample_size=32,
        hidden_size=1152,
        num_layers: int = 28,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = True,
        cross_attention_dim: int = 1024,
        norm_type: str = "layer_norm",
        cross_attention_dim_t5: int = 2048,
        pooled_projection_dim: int = 1024,
        text_len: int = 77,
        text_len_t5: int = 256,
        use_style_cond_and_image_meta_size: bool = True,
    ):
        super().__init__()
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.num_heads = num_attention_heads
        self.inner_dim = num_attention_heads * attention_head_dim

        self.text_embedder = PixArtAlphaTextProjection(
            in_features=cross_attention_dim_t5,
            hidden_size=cross_attention_dim_t5 * 4,
            out_features=cross_attention_dim,
            act_fn="silu_fp32",
        )

        # One learned row per text position (CLIP positions followed by T5 positions); `forward` substitutes these
        # rows wherever the combined text mask is False.
        self.text_embedding_padding = nn.Parameter(
            torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
        )

        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            patch_size=patch_size,
            pos_embed_type=None,
        )

        self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
            hidden_size,
            pooled_projection_dim=pooled_projection_dim,
            seq_len=text_len_t5,
            cross_attention_dim=cross_attention_dim_t5,
            use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
        )

        # HunyuanDiT Blocks
        self.blocks = nn.ModuleList(
            [
                HunyuanDiTBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    activation_fn=activation_fn,
                    ff_inner_dim=int(self.inner_dim * mlp_ratio),
                    cross_attention_dim=cross_attention_dim,
                    qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
                    skip=layer > num_layers // 2,
                )
                for layer in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        This API is 🧪 experimental.
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedHunyuanAttnProcessor2_0())

    # Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        This API is 🧪 experimental.

        Calling this method before `fuse_qkv_projections()` has ever been called is a no-op.
        """
        # `original_attn_processors` is only created by `fuse_qkv_projections()`, so read it defensively instead
        # of raising `AttributeError` on a model that was never fused.
        original_attn_processors = getattr(self, "original_attn_processors", None)
        if original_attn_processors is not None:
            self.set_attn_processor(original_attn_processors)

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(HunyuanAttnProcessor2_0())

    def forward(
        self,
        hidden_states,
        timestep,
        encoder_hidden_states=None,
        text_embedding_mask=None,
        encoder_hidden_states_t5=None,
        text_embedding_mask_t5=None,
        image_meta_size=None,
        style=None,
        image_rotary_emb=None,
        controlnet_block_samples=None,
        return_dict=True,
    ):
        """
        The [`HunyuanDiT2DModel`] forward method.

        Args:
        hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
            The input tensor.
        timestep ( `torch.LongTensor`, *optional*):
            Used to indicate denoising step.
        encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
            Conditional embeddings for cross attention layer. This is the output of `BertModel`.
        text_embedding_mask: torch.Tensor
            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
            of `BertModel`.
        encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
            Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
        text_embedding_mask_t5: torch.Tensor
            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
            of T5 Text Encoder.
        image_meta_size (torch.Tensor):
            Conditional embedding indicate the image sizes
        style: torch.Tensor:
            Conditional embedding indicate the style
        image_rotary_emb (`torch.Tensor`):
            The image rotary embeddings to apply on query and key tensors during attention calculation.
        controlnet_block_samples (sequence of `torch.Tensor`, *optional*):
            Residuals added to the skip connections. They are taken from the end of the sequence, one per
            skip-connected block, and their number must equal the number of skip connections. The sequence passed
            in is left unmodified.
        return_dict: bool
            Whether to return a dictionary.

        Returns:
            A [`Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple` whose first element is the
            output tensor.

        Raises:
            ValueError: If `controlnet_block_samples` holds more entries than there are skip connections.
        """

        height, width = hidden_states.shape[-2:]

        hidden_states = self.pos_embed(hidden_states)

        temb = self.time_extra_emb(
            timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
        )  # [B, D]

        # text projection
        batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
        encoder_hidden_states_t5 = self.text_embedder(
            encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
        )
        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)

        encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
        text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
        text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()

        # Positions where the mask is False are replaced by the learned padding embedding.
        encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)

        if controlnet_block_samples is not None:
            # Work on a shallow copy: the loop below consumes entries with `pop()` and must not empty the
            # caller's sequence.
            controlnet_block_samples = list(controlnet_block_samples)

        skips = []
        for layer, block in enumerate(self.blocks):
            if layer > self.config.num_layers // 2:
                if controlnet_block_samples is not None:
                    skip = skips.pop() + controlnet_block_samples.pop()
                else:
                    skip = skips.pop()
                hidden_states = block(
                    hidden_states,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_rotary_emb=image_rotary_emb,
                    skip=skip,
                )  # (N, L, D)
            else:
                hidden_states = block(
                    hidden_states,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_rotary_emb=image_rotary_emb,
                )  # (N, L, D)

            if layer < (self.config.num_layers // 2 - 1):
                skips.append(hidden_states)

        if controlnet_block_samples is not None and len(controlnet_block_samples) != 0:
            raise ValueError("The number of controls is not equal to the number of skip connections.")

        # final layer
        hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
        hidden_states = self.proj_out(hidden_states)
        # (N, L, patch_size ** 2 * out_channels)

        # unpatchify: (N, out_channels, H, W)
        patch_size = self.pos_embed.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
        )
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
    def disable_forward_chunking(self):
        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)
from typing import Optional import torch from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid from ..attention import BasicTransformerBlock from ..embeddings import PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle class LatteTransformer3DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True """ A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, offical code: https://github.com/Vchitect/Latte Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): The number of channels in the input. out_channels (`int`, *optional*): The number of channels in the output. num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. attention_bias (`bool`, *optional*): Configure if the `TransformerBlocks` attention should contain a bias parameter. sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). This is fixed during training since it is used to learn a number of position embeddings. patch_size (`int`, *optional*): The size of the patches to use in the patch embedding layer. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. num_embeds_ada_norm ( `int`, *optional*): The number of diffusion steps used during training. Pass if at least one of the norm_layers is `AdaLayerNorm`. 
This is fixed during training since it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. norm_type (`str`, *optional*, defaults to `"layer_norm"`): The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`. norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Whether or not to use elementwise affine in normalization layers. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers. caption_channels (`int`, *optional*): The number of channels in the caption embeddings. video_length (`int`, *optional*): The number of frames in the video-like data. """ @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: int = 64, patch_size: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, norm_type: str = "layer_norm", norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, caption_channels: int = None, video_length: int = 16, ): super().__init__() inner_dim = num_attention_heads * attention_head_dim # 1. Define input layers self.height = sample_size self.width = sample_size interpolation_scale = self.config.sample_size // 64 interpolation_scale = max(interpolation_scale, 1) self.pos_embed = PatchEmbed( height=sample_size, width=sample_size, patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, interpolation_scale=interpolation_scale, ) # 2. 
Define spatial transformers blocks self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, norm_type=norm_type, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, ) for d in range(num_layers) ] ) # 3. Define temporal transformers blocks self.temporal_transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=None, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, norm_type=norm_type, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, ) for d in range(num_layers) ] ) # 4. Define output layers self.out_channels = in_channels if out_channels is None else out_channels self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) # 5. Latte other blocks. 
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) # define temporal positional embedding temp_pos_embed = get_1d_sincos_pos_embed_from_grid( inner_dim, torch.arange(0, video_length).unsqueeze(1) ) # 1152 hidden size self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) self.gradient_checkpointing = False def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value def forward( self, hidden_states: torch.Tensor, timestep: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, enable_temporal_attentions: bool = True, return_dict: bool = True, ): """ The [`LatteTransformer3DModel`] forward method. Args: hidden_states shape `(batch size, channel, num_frame, height, width)`: Input `hidden_states`. timestep ( `torch.LongTensor`, *optional*): Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. encoder_attention_mask ( `torch.Tensor`, *optional*): Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: * Mask `(batcheight, sequence_length)` True = keep, False = discard. * Bias `(batcheight, 1, sequence_length)` 0 = keep, -10000 = discard. If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores. enable_temporal_attentions: (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ # Reshape hidden states batch_size, channels, num_frame, height, width = hidden_states.shape # batch_size channels num_frame height width -> (batch_size * num_frame) channels height width hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width) # Input height, width = ( hidden_states.shape[-2] // self.config.patch_size, hidden_states.shape[-1] // self.config.patch_size, ) num_patches = height * width hidden_states = self.pos_embed(hidden_states) # alrady add positional embeddings added_cond_kwargs = {"resolution": None, "aspect_ratio": None} timestep, embedded_timestep = self.adaln_single( timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype ) # Prepare text embeddings for spatial block # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152 encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view( -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1] ) # Prepare timesteps for spatial and temporal block timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1]) timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1]) # Spatial and temporal transformer blocks for i, (spatial_block, temp_block) in enumerate( zip(self.transformer_blocks, self.temporal_transformer_blocks) ): if self.training and self.gradient_checkpointing: hidden_states = torch.utils.checkpoint.checkpoint( spatial_block, hidden_states, None, # attention_mask 
encoder_hidden_states_spatial, encoder_attention_mask, timestep_spatial, None, # cross_attention_kwargs None, # class_labels use_reentrant=False, ) else: hidden_states = spatial_block( hidden_states, None, # attention_mask encoder_hidden_states_spatial, encoder_attention_mask, timestep_spatial, None, # cross_attention_kwargs None, # class_labels ) if enable_temporal_attentions: # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size hidden_states = hidden_states.reshape( batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1] ).permute(0, 2, 1, 3) hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) if i == 0 and num_frame > 1: hidden_states = hidden_states + self.temp_pos_embed if self.training and self.gradient_checkpointing: hidden_states = torch.utils.checkpoint.checkpoint( temp_block, hidden_states, None, # attention_mask None, # encoder_hidden_states None, # encoder_attention_mask timestep_temp, None, # cross_attention_kwargs None, # class_labels use_reentrant=False, ) else: hidden_states = temp_block( hidden_states, None, # attention_mask None, # encoder_hidden_states None, # encoder_attention_mask timestep_temp, None, # cross_attention_kwargs None, # class_labels ) # (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_size hidden_states = hidden_states.reshape( batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1] ).permute(0, 2, 1, 3) hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1]) shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) # Modulation hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.proj_out(hidden_states) # unpatchify if self.adaln_single is 
None: height = width = int(hidden_states.shape[1] ** 0.5) hidden_states = hidden_states.reshape( shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels) ) hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) output = hidden_states.reshape( shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size) ) output = output.reshape(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]).permute( 0, 2, 1, 3, 4 ) if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) ================================================ FILE: src/diffusers/models/transformers/lumina_nextdit2d.py ================================================ # Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import LuminaFeedForward
from ..attention_processor import Attention, LuminaAttnProcessor2_0
from ..embeddings import (
    LuminaCombinedTimestepCaptionEmbedding,
    LuminaPatchEmbed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class LuminaNextDiTBlock(nn.Module):
    """
    A LuminaNextDiTBlock for LuminaNextDiT2DModel.

    Parameters:
        dim (`int`): Embedding dimension of the input features.
        num_attention_heads (`int`): Number of attention heads.
        num_kv_heads (`int`):
            Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
        multiple_of (`int`): The number of multiple of ffn layer.
        ffn_dim_multiplier (`float`): The multiplier factor of ffn layer dimension.
        norm_eps (`float`): The eps for norm layer.
        qk_norm (`bool`): normalization for query and key.
        cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states.
        norm_elementwise_affine (`bool`, *optional*, defaults to True),
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        cross_attention_dim: int,
        norm_elementwise_affine: bool = True,
    ) -> None:
        super().__init__()
        self.head_dim = dim // num_attention_heads

        # One learnable gate per head, initialised to zero; `forward` applies `tanh` to it before
        # scaling the cross-attention output.
        self.gate = nn.Parameter(torch.zeros([num_attention_heads]))

        # Self-attention
        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="layer_norm_across_heads" if qk_norm else None,
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=LuminaAttnProcessor2_0(),
        )
        # The output projection of `attn1` is disabled: `forward` mixes the self- and cross-attention
        # outputs first and projects the sum once with `attn2.to_out[0]`.
        self.attn1.to_out = nn.Identity()

        # Cross-attention
        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            dim_head=dim // num_attention_heads,
            qk_norm="layer_norm_across_heads" if qk_norm else None,
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=LuminaAttnProcessor2_0(),
        )

        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )

        self.norm1 = LuminaRMSNormZero(
            embedding_dim=dim,
            norm_eps=norm_eps,
            norm_elementwise_affine=norm_elementwise_affine,
        )
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

        self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

        self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_mask: torch.Tensor,
        temb: torch.Tensor,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Perform a forward pass through the LuminaNextDiTBlock.

        Parameters:
            hidden_states (`torch.Tensor`): The input of hidden_states for LuminaNextDiTBlock.
            attention_mask (`torch.Tensor`): The attention mask corresponding to `hidden_states`.
            image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
            encoder_hidden_states: (`torch.Tensor`): The hidden_states of text prompt are processed by Gemma encoder.
            encoder_mask (`torch.Tensor`): The hidden_states of text prompt attention mask.
            temb (`torch.Tensor`): Timestep embedding with text prompt embedding.
            cross_attention_kwargs (`Dict[str, Any]`, *optional*):
                Extra keyword arguments forwarded to both attention calls. `None` is treated as an empty dict.

        Returns:
            `torch.Tensor`: The updated hidden states.
        """
        # `**None` raises a TypeError, so a missing kwargs dict (the default) must become an empty one.
        cross_attention_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs

        residual = hidden_states

        # Self-attention
        norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
        self_attn_output = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_hidden_states,
            attention_mask=attention_mask,
            query_rotary_emb=image_rotary_emb,
            key_rotary_emb=image_rotary_emb,
            **cross_attention_kwargs,
        )

        # Cross-attention (rotary embedding on the image queries only; text keys get none)
        norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
        cross_attn_output = self.attn2(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=encoder_mask,
            query_rotary_emb=image_rotary_emb,
            key_rotary_emb=None,
            **cross_attention_kwargs,
        )
        # Per-head tanh gate on the cross-attention branch.
        cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1)
        mixed_attn_output = self_attn_output + cross_attn_output
        # Merge the last two dims before the shared output projection.
        mixed_attn_output = mixed_attn_output.flatten(-2)
        # linear proj
        hidden_states = self.attn2.to_out[0](mixed_attn_output)

        hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states)

        mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))

        hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)

        return hidden_states


class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
    """
    LuminaNextDiT: Diffusion model with a Transformer backbone.

    Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`, *optional*, defaults to 2):
            The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
        in_channels (`int`, *optional*, defaults to 4):
            The number of input channels for the model. Typically, this matches the number of channels in the input
            images.
        hidden_size (`int`, *optional*, defaults to 2304):
            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
            hidden representations.
        num_layers (`int`, *optional*, default to 32):
            The number of layers in the model. This defines the depth of the neural network.
        num_attention_heads (`int`, *optional*, defaults to 32):
            The number of attention heads in each attention layer. This parameter specifies how many separate attention
            mechanisms are used.
        num_kv_heads (`int`, *optional*, defaults to `None`):
            The number of key-value heads in the attention mechanism, if different from the number of attention heads.
            If None, it defaults to num_attention_heads.
        multiple_of (`int`, *optional*, defaults to 256):
            A factor that the hidden size should be a multiple of. This can help optimize certain hardware
            configurations.
        ffn_dim_multiplier (`float`, *optional*):
            A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
            the model configuration.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            A small value added to the denominator for numerical stability in normalization layers.
        learn_sigma (`bool`, *optional*, defaults to True):
            Whether the model should learn the sigma parameter, which might be related to uncertainty or variance in
            predictions. When enabled, the number of output channels is twice `in_channels`.
        qk_norm (`bool`, *optional*, defaults to True):
            Indicates if the queries and keys in the attention mechanism should be normalized.
        cross_attention_dim (`int`, *optional*, defaults to 2048):
            The dimensionality of the text embeddings. This parameter defines the size of the text representations used
            in the model.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
            overall scale of the model's operations.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: Optional[int] = 2,
        in_channels: Optional[int] = 4,
        hidden_size: Optional[int] = 2304,
        num_layers: Optional[int] = 32,
        num_attention_heads: Optional[int] = 32,
        num_kv_heads: Optional[int] = None,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: Optional[float] = 1e-5,
        learn_sigma: Optional[bool] = True,
        qk_norm: Optional[bool] = True,
        cross_attention_dim: Optional[int] = 2048,
        scaling_factor: Optional[float] = 1.0,
    ) -> None:
        super().__init__()
        self.sample_size = sample_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.scaling_factor = scaling_factor

        self.patch_embedder = LuminaPatchEmbed(
            patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True
        )

        # NOTE(review): allocated with `torch.empty`, i.e. uninitialised until weights are loaded.
        self.pad_token = nn.Parameter(torch.empty(hidden_size))

        # The conditioning embedding width is capped at 1024; `norm_out` below uses the same width.
        self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(
            hidden_size=min(hidden_size, 1024), cross_attention_dim=cross_attention_dim
        )

        self.layers = nn.ModuleList(
            [
                LuminaNextDiTBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    cross_attention_dim,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            bias=True,
            out_dim=patch_size * patch_size * self.out_channels,
        )
        # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)

        assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        cross_attention_kwargs: Dict[str, Any] = None,
        return_dict=True,
    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
        """
        Forward pass of LuminaNextDiT.

        Parameters:
            hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
            timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
            encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D).
            encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
            image_rotary_emb (torch.Tensor):
                Rotary embedding handed to `patch_embedder`; the value it returns is used by every layer.
            cross_attention_kwargs (`Dict[str, Any]`, *optional*):
                Extra keyword arguments forwarded to every `LuminaNextDiTBlock`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            A `Transformer2DModelOutput` if `return_dict` is True, otherwise a one-element `tuple` holding the
            output tensor.
        """
        hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
        image_rotary_emb = image_rotary_emb.to(hidden_states.device)

        temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)

        encoder_mask = encoder_mask.bool()
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                mask,
                image_rotary_emb,
                encoder_hidden_states,
                encoder_mask,
                temb=temb,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        hidden_states = self.norm_out(hidden_states, temb)

        # unpatchify
        height_tokens = width_tokens = self.patch_size
        # Only the first entry of `img_size` is read, so one (height, width) is applied to the whole batch.
        height, width = img_size[0]
        batch_size = hidden_states.size(0)
        sequence_length = (height // height_tokens) * (width // width_tokens)
        # Keep only the first `sequence_length` tokens before folding them back into a patch grid.
        hidden_states = hidden_states[:, :sequence_length].view(
            batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
        )
        # (N, h, w, p, q, C) -> (N, C, h, p, w, q) -> (N, C, h * p, w * q)
        output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
================================================ FILE: src/diffusers/models/transformers/pixart_transformer_2d.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, Optional, Union import torch from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import is_torch_version, logging from ..attention import BasicTransformerBlock from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle logger = logging.get_logger(__name__) # pylint: disable=invalid-name class PixArtTransformer2DModel(ModelMixin, ConfigMixin): r""" A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426, https://arxiv.org/abs/2403.04692). Parameters: num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (int, optional, defaults to 72): The number of channels in each head. in_channels (int, defaults to 4): The number of channels in the input. out_channels (int, optional): The number of channels in the output. Specify this parameter if the output channel number differs from the input. 
num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use. dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks. norm_num_groups (int, optional, defaults to 32): Number of groups for group normalization within Transformer blocks. cross_attention_dim (int, optional): The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension. attention_bias (bool, optional, defaults to True): Configure if the Transformer blocks' attention should contain a bias parameter. sample_size (int, defaults to 128): The width of the latent images. This parameter is fixed during training. patch_size (int, defaults to 2): Size of the patches the model processes, relevant for architectures working on non-sequential data. activation_fn (str, optional, defaults to "gelu-approximate"): Activation function to use in feed-forward networks within Transformer blocks. num_embeds_ada_norm (int, optional, defaults to 1000): Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during inference. upcast_attention (bool, optional, defaults to False): If true, upcasts the attention mechanism dimensions for potentially improved performance. norm_type (str, optional, defaults to "ada_norm_zero"): Specifies the type of normalization used, can be 'ada_norm_zero'. norm_elementwise_affine (bool, optional, defaults to False): If true, enables element-wise affine parameters in the normalization layers. norm_eps (float, optional, defaults to 1e-6): A small constant added to the denominator in normalization layers to prevent division by zero. interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings. use_additional_conditions (bool, optional): If we're using additional conditions as inputs. attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used. 
caption_channels (int, optional, defaults to None): Number of channels to use for projecting the caption embeddings. use_linear_projection (bool, optional, defaults to False): Deprecated argument. Will be removed in a future version. num_vector_embeds (bool, optional, defaults to False): Deprecated argument. Will be removed in a future version. """ _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 72, in_channels: int = 4, out_channels: Optional[int] = 8, num_layers: int = 28, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = 1152, attention_bias: bool = True, sample_size: int = 128, patch_size: int = 2, activation_fn: str = "gelu-approximate", num_embeds_ada_norm: Optional[int] = 1000, upcast_attention: bool = False, norm_type: str = "ada_norm_single", norm_elementwise_affine: bool = False, norm_eps: float = 1e-6, interpolation_scale: Optional[int] = None, use_additional_conditions: Optional[bool] = None, caption_channels: Optional[int] = None, attention_type: Optional[str] = "default", ): super().__init__() # Validate inputs. if norm_type != "ada_norm_single": raise NotImplementedError( f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." ) elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None: raise ValueError( f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." ) # Set some common variables used across the board. 
self.attention_head_dim = attention_head_dim self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.out_channels = in_channels if out_channels is None else out_channels if use_additional_conditions is None: if sample_size == 128: use_additional_conditions = True else: use_additional_conditions = False self.use_additional_conditions = use_additional_conditions self.gradient_checkpointing = False # 2. Initialize the position embedding and transformer blocks. self.height = self.config.sample_size self.width = self.config.sample_size interpolation_scale = ( self.config.interpolation_scale if self.config.interpolation_scale is not None else max(self.config.sample_size // 64, 1) ) self.pos_embed = PatchEmbed( height=self.config.sample_size, width=self.config.sample_size, patch_size=self.config.patch_size, in_channels=self.config.in_channels, embed_dim=self.inner_dim, interpolation_scale=interpolation_scale, ) self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( self.inner_dim, self.config.num_attention_heads, self.config.attention_head_dim, dropout=self.config.dropout, cross_attention_dim=self.config.cross_attention_dim, activation_fn=self.config.activation_fn, num_embeds_ada_norm=self.config.num_embeds_ada_norm, attention_bias=self.config.attention_bias, upcast_attention=self.config.upcast_attention, norm_type=norm_type, norm_elementwise_affine=self.config.norm_elementwise_affine, norm_eps=self.config.norm_eps, attention_type=self.config.attention_type, ) for _ in range(self.config.num_layers) ] ) # 3. Output blocks. 
def set_attn_processor(self, processor: Union["AttentionProcessor", Dict[str, "AttentionProcessor"]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the
            processor for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

    Raises:
        ValueError: If a dict is passed whose size differs from the number of attention layers.
        KeyError: If a dict is passed that lacks the `"<module path>.processor"` key of some layer.

    The dict passed by the caller is left untouched, so the same mapping can be applied again later.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict):
        if len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )
        # The recursion below pops entries as it assigns them. Work on a shallow copy so the caller's
        # mapping (e.g. the one saved by `fuse_qkv_projections`) is not emptied as a side effect.
        processor = dict(processor)

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        # Any module exposing `set_processor` receives either the shared processor or its own entry.
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                module.set_processor(processor.pop(f"{name}.processor"))

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    Restores the attention processors that `fuse_qkv_projections` saved in
    `self.original_attn_processors`. Does nothing when no processors were saved, including when
    `fuse_qkv_projections` was never called on this model.
    """
    # `original_attn_processors` is only assigned by `fuse_qkv_projections`; use `getattr` so that
    # calling this method on a model that was never fused is a no-op instead of an AttributeError.
    saved_processors = getattr(self, "original_attn_processors", None)
    if saved_processors is None:
        return

    # `set_attn_processor` pops entries out of a dict argument, so hand it a copy and keep the
    # saved mapping intact for later calls.
    if isinstance(saved_processors, dict):
        saved_processors = dict(saved_processors)
    self.set_attn_processor(saved_processors)
cross_attention_kwargs ( `Dict[str, Any]`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). attention_mask ( `torch.Tensor`, *optional*): An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. encoder_attention_mask ( `torch.Tensor`, *optional*): Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: * Mask `(batch, sequence_length)` True = keep, False = discard. * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ if self.use_additional_conditions and added_cond_kwargs is None: raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.") # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 
# expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None and attention_mask.ndim == 2: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 1. Input batch_size = hidden_states.shape[0] height, width = ( hidden_states.shape[-2] // self.config.patch_size, hidden_states.shape[-1] // self.config.patch_size, ) hidden_states = self.pos_embed(hidden_states) timestep, embedded_timestep = self.adaln_single( timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype ) if self.caption_projection is not None: encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) # 2. 
Blocks for block in self.transformer_blocks: if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, None, **ckpt_kwargs, ) else: hidden_states = block( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=None, ) # 3. Output shift, scale = ( self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device) ).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) # Modulation hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.squeeze(1) # unpatchify hidden_states = hidden_states.reshape( shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels) ) hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) output = hidden_states.reshape( shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size) ) if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) ================================================ FILE: src/diffusers/models/transformers/prior_transformer.py ================================================ from dataclasses import dataclass from typing import Dict, Optional, Union import torch import torch.nn.functional as F from torch import nn from 
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
    """
    A Prior Transformer model.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
        num_embeddings (`int`, *optional*, defaults to 77):
            The number of embeddings of the model input `hidden_states`
        additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
            projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
            additional_embeddings`.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
            The activation function to use to create timestep embeddings.
        norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
            passing to Transformer blocks. Set it to `None` if normalization is not needed.
        embedding_proj_norm_type (`str`, *optional*, defaults to None):
            The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
            needed.
        encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
            The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
            `encoder_hidden_states` is `None`.
        added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
            Choose from `prd` or `None`. If `prd` is chosen, it will prepend a token indicating the (quantized) dot
            product between the text embedding and image embedding as proposed in the unclip paper
            https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
        time_embed_dim (`int`, *optional*, defaults to None): The dimension of timestep embeddings.
            If None, will be set to `num_attention_heads * attention_head_dim`
        embedding_proj_dim (`int`, *optional*, default to None):
            The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
        clip_embed_dim (`int`, *optional*, default to None):
            The dimension of the output. If None, will be set to `embedding_dim`.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        num_layers: int = 20,
        embedding_dim: int = 768,
        num_embeddings=77,
        additional_embeddings=4,
        dropout: float = 0.0,
        time_embed_act_fn: str = "silu",
        norm_in_type: Optional[str] = None,  # layer
        embedding_proj_norm_type: Optional[str] = None,  # layer
        encoder_hid_proj_type: Optional[str] = "linear",  # linear
        added_emb_type: Optional[str] = "prd",  # prd
        time_embed_dim: Optional[int] = None,
        embedding_proj_dim: Optional[int] = None,
        clip_embed_dim: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.additional_embeddings = additional_embeddings

        # Unset (or zero) dimensions fall back to the sizes derived from the other arguments.
        time_embed_dim = time_embed_dim or inner_dim
        embedding_proj_dim = embedding_proj_dim or embedding_dim
        clip_embed_dim = clip_embed_dim or embedding_dim

        self.time_proj = Timesteps(inner_dim, True, 0)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)

        self.proj_in = nn.Linear(embedding_dim, inner_dim)

        if embedding_proj_norm_type is None:
            self.embedding_proj_norm = None
        elif embedding_proj_norm_type == "layer":
            self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
        else:
            raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")

        self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)

        if encoder_hid_proj_type is None:
            self.encoder_hidden_states_proj = None
        elif encoder_hid_proj_type == "linear":
            self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
        else:
            raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")

        self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))

        if added_emb_type == "prd":
            self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
        elif added_emb_type is None:
            self.prd_embedding = None
        else:
            raise ValueError(
                f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
            )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    activation_fn="gelu",
                    attention_bias=True,
                )
                for _ in range(num_layers)
            ]
        )

        if norm_in_type == "layer":
            self.norm_in = nn.LayerNorm(inner_dim)
        elif norm_in_type is None:
            self.norm_in = None
        else:
            raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")

        self.norm_out = nn.LayerNorm(inner_dim)

        self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)

        # Additive causal mask: -10000.0 strictly above the diagonal, 0 elsewhere, with a leading
        # singleton dimension so it broadcasts over the batch.
        causal_attention_mask = torch.full(
            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
        )
        causal_attention_mask.triu_(1)
        causal_attention_mask = causal_attention_mask[None, ...]
        self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)

        self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
        self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention layers.

        The dict passed by the caller is left untouched.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict):
            if len(processor) != count:
                raise ValueError(
                    f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                    f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
                )
            # The recursion below pops entries as it assigns them; consume a shallow copy so the
            # caller's mapping is not emptied as a side effect.
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.

        Raises:
            ValueError: If the current processors are neither all added-KV processors nor all cross-attention
                processors.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    def forward(
        self,
        hidden_states,
        timestep: Union[torch.Tensor, float, int],
        proj_embedding: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`PriorTransformer`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                The currently predicted image embeddings.
            timestep (`torch.LongTensor`):
                Current denoising step.
            proj_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                Projected embedding vector the denoising process is conditioned on.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                Hidden states of the text embeddings the denoising process is conditioned on.
            attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                Text mask for the text embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformers.prior_transformer.PriorTransformerOutput`] instead
                of a plain tuple.

        Returns:
            [`~models.transformers.prior_transformer.PriorTransformerOutput`] or `tuple`:
                If return_dict is True, a [`~models.transformers.prior_transformer.PriorTransformerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If the model has an `encoder_hidden_states_proj` layer but `encoder_hidden_states` is `None`.
        """
        batch_size = hidden_states.shape[0]

        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
        elif len(timesteps.shape) == 0:
            # 0-dim tensor: give it a leading dimension and move it next to the inputs.
            timesteps = timesteps[None].to(hidden_states.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)

        timesteps_projected = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might be fp16, so we need to cast here.
        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
        time_embeddings = self.time_embedding(timesteps_projected)

        if self.embedding_proj_norm is not None:
            proj_embedding = self.embedding_proj_norm(proj_embedding)

        proj_embeddings = self.embedding_proj(proj_embedding)
        if self.encoder_hidden_states_proj is not None:
            if encoder_hidden_states is None:
                raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
            encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)

        hidden_states = self.proj_in(hidden_states)

        positional_embeddings = self.positional_embedding.to(hidden_states.dtype)

        # Token sequence is assembled as:
        # [encoder_hidden_states?, proj_embeddings, time_embeddings, hidden_states, prd_embedding?]
        additional_embeds = []
        additional_embeddings_len = 0

        if encoder_hidden_states is not None:
            additional_embeds.append(encoder_hidden_states)
            additional_embeddings_len += encoder_hidden_states.shape[1]

        if len(proj_embeddings.shape) == 2:
            proj_embeddings = proj_embeddings[:, None, :]

        if len(hidden_states.shape) == 2:
            hidden_states = hidden_states[:, None, :]

        additional_embeds = additional_embeds + [
            proj_embeddings,
            time_embeddings[:, None, :],
            hidden_states,
        ]

        if self.prd_embedding is not None:
            prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
            additional_embeds.append(prd_embedding)

        hidden_states = torch.cat(
            additional_embeds,
            dim=1,
        )

        # Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for
        # these additional tokens. The `+ 1` accounts for the single time-embedding token.
        additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
        if positional_embeddings.shape[1] < hidden_states.shape[1]:
            positional_embeddings = F.pad(
                positional_embeddings,
                (
                    0,
                    0,
                    additional_embeddings_len,
                    self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
                ),
                value=0.0,
            )

        hidden_states = hidden_states + positional_embeddings

        if attention_mask is not None:
            # Turn the keep/discard mask into an additive bias (keep = 0, discard = -10000.0), extend it
            # with zeros for the additional tokens, and combine it with the causal mask.
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
            attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)

        if self.norm_in is not None:
            hidden_states = self.norm_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)

        hidden_states = self.norm_out(hidden_states)

        if self.prd_embedding is not None:
            # The prd token was appended last, so its position is the final one.
            hidden_states = hidden_states[:, -1]
        else:
            hidden_states = hidden_states[:, additional_embeddings_len:]

        predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)

        if not return_dict:
            return (predicted_image_embedding,)

        return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)

    def post_process_latents(self, prior_latents):
        """Rescale `prior_latents` as `prior_latents * self.clip_std + self.clip_mean` and return the result."""
        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
        return prior_latents
class StableAudioGaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels.

    Holds a frozen, randomly initialised frequency vector `weight` of length `embedding_size` and maps
    each input value to the concatenation of the sine and cosine of `2 * pi * value * weight`.
    """

    # Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__
    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

        if set_W_to_weight:
            # Draw a second random vector and let it take the place of the first one. Only a
            # parameter named `weight` remains registered on the module afterwards.
            replacement = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
            del self.weight
            self.weight = replacement

    def forward(self, x):
        """Embed the 1-D tensor `x`; the result has `2 * embedding_size` features in its last dimension."""
        values = torch.log(x) if self.log else x

        # Outer product of the scaled inputs with the frequency vector.
        scaled = 2 * np.pi * values[:, None]
        x_proj = scaled @ self.weight[None, :]

        sin_part = torch.sin(x_proj)
        cos_part = torch.cos(x_proj)
        # `flip_sin_to_cos` puts the cosine half first.
        halves = [cos_part, sin_part] if self.flip_sin_to_cos else [sin_part, cos_part]
        return torch.cat(halves, dim=-1)
attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. upcast_attention (`bool`, *optional*): Whether to upcast the attention computation to float32. This is useful for mixed precision training. """ def __init__( self, dim: int, num_attention_heads: int, num_key_value_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, upcast_attention: bool = False, norm_eps: float = 1e-5, ff_inner_dim: Optional[int] = None, ): super().__init__() # Define 3 blocks. Each block has its own normalization layer. # 1. Self-Attn self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps) self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=False, upcast_attention=upcast_attention, out_bias=False, processor=StableAudioAttnProcessor2_0(), ) # 2. Cross-Attn self.norm2 = nn.LayerNorm(dim, norm_eps, True) self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, heads=num_attention_heads, dim_head=attention_head_dim, kv_heads=num_key_value_attention_heads, dropout=dropout, bias=False, upcast_attention=upcast_attention, out_bias=False, processor=StableAudioAttnProcessor2_0(), ) # is self-attn if encoder_hidden_states is none # 3. 
Feed-forward self.norm3 = nn.LayerNorm(dim, norm_eps, True) self.ff = FeedForward( dim, dropout=dropout, activation_fn="swiglu", final_dropout=False, inner_dim=ff_inner_dim, bias=True, ) # let chunk size default to None self._chunk_size = None self._chunk_dim = 0 def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): # Sets chunk feed-forward self._chunk_size = chunk_size self._chunk_dim = dim def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, rotary_embedding: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: # Notice that normalization is always applied before the real computation in the following blocks. # 0. Self-Attention norm_hidden_states = self.norm1(hidden_states) attn_output = self.attn1( norm_hidden_states, attention_mask=attention_mask, rotary_emb=rotary_embedding, ) hidden_states = attn_output + hidden_states # 2. Cross-Attention norm_hidden_states = self.norm2(hidden_states) attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, ) hidden_states = attn_output + hidden_states # 3. Feed-forward norm_hidden_states = self.norm3(hidden_states) ff_output = self.ff(norm_hidden_states) hidden_states = ff_output + hidden_states return hidden_states class StableAudioDiTModel(ModelMixin, ConfigMixin): """ The Diffusion Transformer model introduced in Stable Audio. Reference: https://github.com/Stability-AI/stable-audio-tools Parameters: sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample. in_channels (`int`, *optional*, defaults to 64): The number of channels in the input. num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use. attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 
num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states. num_key_value_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for the key and value states. out_channels (`int`, defaults to 64): Number of output channels. cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection. time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection. global_states_input_dim ( `int`, *optional*, defaults to 1536): Input dimension of the global hidden states projection. cross_attention_input_dim ( `int`, *optional*, defaults to 768): Input dimension of the cross-attention projection """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: int = 1024, in_channels: int = 64, num_layers: int = 24, attention_head_dim: int = 64, num_attention_heads: int = 24, num_key_value_attention_heads: int = 12, out_channels: int = 64, cross_attention_dim: int = 768, time_proj_dim: int = 256, global_states_input_dim: int = 1536, cross_attention_input_dim: int = 768, ): super().__init__() self.sample_size = sample_size self.out_channels = out_channels self.inner_dim = num_attention_heads * attention_head_dim self.time_proj = StableAudioGaussianFourierProjection( embedding_size=time_proj_dim // 2, flip_sin_to_cos=True, log=False, set_W_to_weight=False, ) self.timestep_proj = nn.Sequential( nn.Linear(time_proj_dim, self.inner_dim, bias=True), nn.SiLU(), nn.Linear(self.inner_dim, self.inner_dim, bias=True), ) self.global_proj = nn.Sequential( nn.Linear(global_states_input_dim, self.inner_dim, bias=False), nn.SiLU(), nn.Linear(self.inner_dim, self.inner_dim, bias=False), ) self.cross_attention_proj = nn.Sequential( nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False), nn.SiLU(), nn.Linear(cross_attention_dim, cross_attention_dim, bias=False), ) self.preprocess_conv = nn.Conv1d(in_channels, 
in_channels, 1, bias=False) self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False) self.transformer_blocks = nn.ModuleList( [ StableAudioDiTBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, num_key_value_attention_heads=num_key_value_attention_heads, attention_head_dim=attention_head_dim, cross_attention_dim=cross_attention_dim, ) for i in range(num_layers) ] ) self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False) self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False) self.gradient_checkpointing = False @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. 
def forward(
    self,
    hidden_states: torch.FloatTensor,
    timestep: torch.LongTensor = None,
    encoder_hidden_states: torch.FloatTensor = None,
    global_hidden_states: torch.FloatTensor = None,
    rotary_embedding: torch.FloatTensor = None,
    return_dict: bool = True,
    attention_mask: Optional[torch.LongTensor] = None,
    encoder_attention_mask: Optional[torch.LongTensor] = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
    """
    The [`StableAudioDiTModel`] forward method.

    Args:
        hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`):
            Input `hidden_states`.
        timestep ( `torch.LongTensor`):
            Used to indicate denoising step.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`):
            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`):
            Global embeddings that will be prepended to the hidden states. All `global_sequence_len` tokens are
            prepended, and the same number of tokens is stripped from the output.
        rotary_embedding (`torch.Tensor`):
            The rotary embeddings to apply on query and key tensors during attention calculation.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            It is extended with ones for the prepended global tokens before being passed to the blocks.
        encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
            Mask to avoid performing attention on padding token cross-attention indices. Mask values selected in
            `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states)
    global_hidden_states = self.global_proj(global_hidden_states)
    time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype)))

    # Broadcast the timestep embedding over every global token.
    global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1)

    hidden_states = self.preprocess_conv(hidden_states) + hidden_states
    # (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim)
    hidden_states = hidden_states.transpose(1, 2)

    hidden_states = self.proj_in(hidden_states)

    # prepend global states to hidden states; remember how many tokens were added so that the mask extension
    # and the final strip stay consistent for any `global_sequence_len` (not only 1).
    num_global_tokens = global_hidden_states.shape[-2]
    hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2)
    if attention_mask is not None:
        prepend_mask = torch.ones(
            (hidden_states.shape[0], num_global_tokens), device=hidden_states.device, dtype=torch.bool
        )
        attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)

    use_checkpointing = self.training and self.gradient_checkpointing
    if use_checkpointing:
        # Loop-invariant helpers: built once instead of once per transformer block.
        def create_custom_forward(module, return_dict=None):
            def custom_forward(*inputs):
                if return_dict is not None:
                    return module(*inputs, return_dict=return_dict)
                else:
                    return module(*inputs)

            return custom_forward

        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

    for block in self.transformer_blocks:
        if use_checkpointing:
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                attention_mask,
                cross_attention_hidden_states,
                encoder_attention_mask,
                rotary_embedding,
                **ckpt_kwargs,
            )
        else:
            hidden_states = block(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=cross_attention_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_embedding=rotary_embedding,
            )

    hidden_states = self.proj_out(hidden_states)

    # (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length)
    # remove the prepended tokens that were added for the global hidden states
    hidden_states = hidden_states.transpose(1, 2)[:, :, num_global_tokens:]
    hidden_states = self.postprocess_conv(hidden_states) + hidden_states

    if not return_dict:
        return (hidden_states,)

    return Transformer2DModelOutput(sample=hidden_states)
def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
    """
    Decode `decoder_input_tokens` conditioned on the noise time and on the given encodings.

    Args:
        encodings_and_masks: Iterable of `(encoding, mask)` pairs. The encodings are concatenated along dim 1 and
            the masks are turned into encoder-decoder masks and concatenated along the last dim.
        decoder_input_tokens: 3-D tensor whose first two dims are `(batch, seq_length)`.
        decoder_noise_time: Tensor of shape `(batch,)`; it is multiplied by `config.max_decoder_noise_time`
            before being embedded.

    Returns:
        The output of the final `spec_out` projection.
    """
    batch, seq_length, _ = decoder_input_tokens.shape
    assert decoder_noise_time.shape == (batch,)

    max_noise_time = self.config.max_decoder_noise_time
    d_model = self.config.d_model
    device = decoder_input_tokens.device

    # decoder_noise_time is in [0, 1), so rescale to expected timing range.
    time_steps = get_timestep_embedding(
        decoder_noise_time * max_noise_time,
        embedding_dim=d_model,
        max_period=max_noise_time,
    )
    time_steps = time_steps.to(dtype=self.dtype)

    conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
    assert conditioning_emb.shape == (batch, 1, d_model * 4)

    # If we want to use relative positions for audio context, we can just offset
    # this sequence by the length of encodings_and_masks.
    positions = torch.arange(seq_length, device=device)
    decoder_positions = torch.broadcast_to(positions, (batch, seq_length))
    position_encodings = self.position_encoding(decoder_positions)

    inputs = self.continuous_inputs_projection(decoder_input_tokens)
    inputs += position_encodings
    hidden = self.dropout(inputs)

    # decoder: No padding present.
    decoder_mask = torch.ones((batch, seq_length), device=device, dtype=inputs.dtype)

    # Translate encoding masks to encoder-decoder masks, keeping each encoding next to its mask.
    encodings = []
    encdec_masks = []
    for encoding, encoding_mask in encodings_and_masks:
        encodings.append(encoding)
        encdec_masks.append(self.encoder_decoder_mask(decoder_mask, encoding_mask))

    # cross attend style: concat encodings
    encoded = torch.cat(encodings, dim=1)
    encoder_decoder_mask = torch.cat(encdec_masks, dim=-1)

    for decoder_layer in self.decoders:
        layer_outputs = decoder_layer(
            hidden,
            conditioning_emb=conditioning_emb,
            encoder_hidden_states=encoded,
            encoder_attention_mask=encoder_decoder_mask,
        )
        hidden = layer_outputs[0]

    hidden = self.post_dropout(self.decoder_norm(hidden))
    return self.spec_out(hidden)
def forward(
    self,
    hidden_states: torch.Tensor,
    conditioning_emb: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    encoder_decoder_position_bias=None,
) -> Tuple[torch.Tensor]:
    """
    Run conditioned self-attention, optional cross-attention, and the FiLM-conditioned feed-forward layer.

    Args:
        hidden_states: Input to the layer.
        conditioning_emb: Conditioning embedding forwarded to the self-attention and feed-forward sub-layers.
        attention_mask: Forwarded unchanged to the self-attention sub-layer.
        encoder_hidden_states: Key/value states for cross-attention. Cross-attention is skipped when `None`.
        encoder_attention_mask: Mask for cross-attention; entries `> 0` are kept (bias 0), all others receive a
            bias of `-1e10`. When `None`, nothing is masked.
        encoder_decoder_position_bias: Accepted for interface compatibility; not used by this method.

    Returns:
        A 1-tuple containing the output hidden states.
    """
    hidden_states = self.layer[0](
        hidden_states,
        conditioning_emb=conditioning_emb,
        attention_mask=attention_mask,
    )

    if encoder_hidden_states is not None:
        if encoder_attention_mask is None:
            # No mask supplied: use an all-zero additive bias (attend everywhere) instead of failing on
            # `None > 0`. The shape (batch, 1, query_len, key_len) mirrors what
            # `T5FilmDecoder.encoder_decoder_mask` produces.
            encoder_extended_attention_mask = encoder_hidden_states.new_zeros(
                (encoder_hidden_states.shape[0], 1, hidden_states.shape[1], encoder_hidden_states.shape[1])
            )
        else:
            # Convert the keep/discard mask into an additive bias.
            encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
                encoder_hidden_states.dtype
            )

        hidden_states = self.layer[1](
            hidden_states,
            key_value_states=encoder_hidden_states,
            attention_mask=encoder_extended_attention_mask,
        )

    # Apply Film Conditional Feed Forward layer
    hidden_states = self.layer[-1](hidden_states, conditioning_emb)

    return (hidden_states,)
class T5LayerCrossAttention(nn.Module):
    r"""
    T5 style cross-attention layer.

    Args:
        d_model (`int`):
            Size of the input hidden states.
        d_kv (`int`):
            Size of the key-value projection vectors.
        num_heads (`int`):
            Number of attention heads.
        dropout_rate (`float`):
            Dropout probability.
        layer_norm_epsilon (`float`):
            A small value used for numerical stability to avoid dividing by zero.
    """

    def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
        super().__init__()
        self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
        self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply pre-norm cross-attention with a residual connection.

        Args:
            hidden_states: Query-side input; also the residual branch.
            key_value_states: Passed to the attention module as `encoder_hidden_states`.
            attention_mask: Optional mask; its dim 1 is squeezed before it is handed to the attention module.
                `None` is forwarded as-is.

        Returns:
            `hidden_states` plus the dropout of the attention output.
        """
        normed_hidden_states = self.layer_norm(hidden_states)

        # The mask is optional: only drop its singleton dim when one was actually supplied.
        if attention_mask is not None:
            attention_mask = attention_mask.squeeze(1)

        attention_output = self.attention(
            normed_hidden_states,
            encoder_hidden_states=key_value_states,
            attention_mask=attention_mask,
        )
        layer_output = hidden_states + self.dropout(attention_output)
        return layer_output
class T5DenseGatedActDense(nn.Module):
    r"""
    Gated feed-forward block in the T5 style: an activated projection gates a linear projection, followed by
    dropout and an output projection.

    Args:
        d_model (`int`):
            Size of the input hidden states.
        d_ff (`int`):
            Size of the intermediate feed-forward layer.
        dropout_rate (`float`):
            Dropout probability.
    """

    def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
        super().__init__()
        # wi_0 feeds the activation (gate branch), wi_1 is the plain linear branch.
        self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
        self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.act = NewGELUActivation()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `wo(dropout(act(wi_0(x)) * wi_1(x)))` for the input `x`."""
        gate = self.act(self.wi_0(hidden_states))
        projected = self.wi_1(hidden_states)
        gated = self.dropout(gate * projected)
        return self.wo(gated)
class NewGELUActivation(nn.Module):
    """
    Tanh approximation of the GELU activation, as used in the Google BERT repo (identical to OpenAI GPT). Also
    see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
        cubic_term = 0.044715 * torch.pow(input, 3.0)
        tanh_arg = math.sqrt(2.0 / math.pi) * (input + cubic_term)
        return 0.5 * input * (1.0 + torch.tanh(tanh_arg))
class Transformer2DModelOutput(Transformer2DModelOutput):
    # Backward-compatibility shim: same output class as the one imported from `..modeling_outputs`, but
    # constructing it emits a deprecation notice pointing at the new import location.
    def __init__(self, *args, **kwargs):
        deprecate(
            "Transformer2DModelOutput",
            "1.0.0",
            "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this"
            " will be removed in a future version."
            " Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead.",
        )
        super().__init__(*args, **kwargs)
@register_to_config
def __init__(
    self,
    num_attention_heads: int = 16,
    attention_head_dim: int = 88,
    in_channels: Optional[int] = None,
    out_channels: Optional[int] = None,
    num_layers: int = 1,
    dropout: float = 0.0,
    norm_num_groups: int = 32,
    cross_attention_dim: Optional[int] = None,
    attention_bias: bool = False,
    sample_size: Optional[int] = None,
    num_vector_embeds: Optional[int] = None,
    patch_size: Optional[int] = None,
    activation_fn: str = "geglu",
    num_embeds_ada_norm: Optional[int] = None,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    double_self_attention: bool = False,
    upcast_attention: bool = False,
    norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
    norm_elementwise_affine: bool = True,
    norm_eps: float = 1e-5,
    attention_type: str = "default",
    caption_channels: Optional[int] = None,
    interpolation_scale: Optional[float] = None,
    use_additional_conditions: Optional[bool] = None,
):
    """
    Validate the configuration, decide which input mode is active and build the matching sub-modules.

    Exactly one input mode is selected from the arguments: continuous (`in_channels` set, no `patch_size`),
    vectorized (`num_vector_embeds` set) or patched (`in_channels` and `patch_size` set).

    Raises:
        NotImplementedError: If `patch_size` is set together with a `norm_type` other than `ada_norm`,
            `ada_norm_zero` or `ada_norm_single`.
        ValueError: If `patch_size` is set with `ada_norm`/`ada_norm_zero` but `num_embeds_ada_norm` is None, if
            conflicting input modes are configured, or if no input mode is configured at all.
    """
    super().__init__()

    # Validate inputs.
    if patch_size is not None:
        if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
            raise NotImplementedError(
                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
            )
        elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None:
            raise ValueError(
                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
            )

    # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
    # Define whether input is continuous or discrete depending on configuration
    self.is_input_continuous = (in_channels is not None) and (patch_size is None)
    self.is_input_vectorized = num_vector_embeds is not None
    self.is_input_patches = in_channels is not None and patch_size is not None

    if self.is_input_continuous and self.is_input_vectorized:
        raise ValueError(
            f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
            " sure that either `in_channels` or `num_vector_embeds` is None."
        )
    elif self.is_input_vectorized and self.is_input_patches:
        # The argument is called `patch_size`; the message names it accordingly.
        raise ValueError(
            f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
            " sure that either `num_vector_embeds` or `patch_size` is None."
        )
    elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
        raise ValueError(
            f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
            f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `patch_size` is not None."
        )

    # Outdated configs left `norm_type` at its default while setting `num_embeds_ada_norm`; warn and treat
    # them as `ada_norm`.
    if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
        deprecation_message = (
            f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
            " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
            " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
            " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
            " would be very nice if you could open a Pull request for the `transformer/config.json` file"
        )
        deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
        norm_type = "ada_norm"

    # Set some common variables used across the board.
    self.use_linear_projection = use_linear_projection
    self.interpolation_scale = interpolation_scale
    self.caption_channels = caption_channels
    self.num_attention_heads = num_attention_heads
    self.attention_head_dim = attention_head_dim
    self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
    self.in_channels = in_channels
    self.out_channels = in_channels if out_channels is None else out_channels
    self.gradient_checkpointing = False

    # When not given explicitly, additional conditions are enabled only for `ada_norm_single` with a sample
    # size of 128.
    if use_additional_conditions is None:
        if norm_type == "ada_norm_single" and sample_size == 128:
            use_additional_conditions = True
        else:
            use_additional_conditions = False
    self.use_additional_conditions = use_additional_conditions

    # 2. Initialize the right blocks.
    # These functions follow a common structure:
    # a. Initialize the input blocks. b. Initialize the transformer blocks.
    # c. Initialize the output blocks and other projection blocks when necessary.
    if self.is_input_continuous:
        self._init_continuous_input(norm_type=norm_type)
    elif self.is_input_vectorized:
        self._init_vectorized_inputs(norm_type=norm_type)
    elif self.is_input_patches:
        self._init_patched_inputs(norm_type=norm_type)
def _init_patched_inputs(self, norm_type):
    """
    Build the sub-modules used when the model operates on patched inputs.

    Creates, in order: the patch embedding, the transformer blocks, the output normalisation and projection(s),
    and the optional `adaln_single` / `caption_projection` modules. The output head is chosen from
    `self.config.norm_type`, whereas the transformer blocks receive the `norm_type` argument.
    """
    cfg = self.config
    assert cfg.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

    self.height = cfg.sample_size
    self.width = cfg.sample_size
    self.patch_size = cfg.patch_size

    # Fall back to a scale derived from the sample size when none is configured.
    if cfg.interpolation_scale is not None:
        interpolation_scale = cfg.interpolation_scale
    else:
        interpolation_scale = max(cfg.sample_size // 64, 1)

    self.pos_embed = PatchEmbed(
        height=cfg.sample_size,
        width=cfg.sample_size,
        patch_size=cfg.patch_size,
        in_channels=self.in_channels,
        embed_dim=self.inner_dim,
        interpolation_scale=interpolation_scale,
    )

    blocks = []
    for _ in range(cfg.num_layers):
        blocks.append(
            BasicTransformerBlock(
                self.inner_dim,
                cfg.num_attention_heads,
                cfg.attention_head_dim,
                dropout=cfg.dropout,
                cross_attention_dim=cfg.cross_attention_dim,
                activation_fn=cfg.activation_fn,
                num_embeds_ada_norm=cfg.num_embeds_ada_norm,
                attention_bias=cfg.attention_bias,
                only_cross_attention=cfg.only_cross_attention,
                double_self_attention=cfg.double_self_attention,
                upcast_attention=cfg.upcast_attention,
                norm_type=norm_type,
                norm_elementwise_affine=cfg.norm_elementwise_affine,
                norm_eps=cfg.norm_eps,
                attention_type=cfg.attention_type,
            )
        )
    self.transformer_blocks = nn.ModuleList(blocks)

    uses_ada_norm_single = cfg.norm_type == "ada_norm_single"
    patch_out_features = cfg.patch_size * cfg.patch_size * self.out_channels

    # Both output heads start with the same parameter-free LayerNorm.
    self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
    if uses_ada_norm_single:
        self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
        self.proj_out = nn.Linear(self.inner_dim, patch_out_features)
    else:
        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
        self.proj_out_2 = nn.Linear(self.inner_dim, patch_out_features)

    # PixArt-Alpha blocks.
    # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
    # additional conditions until we find better name
    self.adaln_single = (
        AdaLayerNormSingle(self.inner_dim, use_additional_conditions=self.use_additional_conditions)
        if uses_ada_norm_single
        else None
    )

    self.caption_projection = (
        PixArtAlphaTextProjection(in_features=self.caption_channels, hidden_size=self.inner_dim)
        if self.caption_channels is not None
        else None
    )
cross_attention_kwargs ( `Dict[str, Any]`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). attention_mask ( `torch.Tensor`, *optional*): An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. encoder_attention_mask ( `torch.Tensor`, *optional*): Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: * Mask `(batch, sequence_length)` True = keep, False = discard. * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: If `return_dict` is True, an [`~models.transformers.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 
# expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None and attention_mask.ndim == 2: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 1. Input if self.is_input_continuous: batch_size, _, height, width = hidden_states.shape residual = hidden_states hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) elif self.is_input_patches: height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( hidden_states, encoder_hidden_states, timestep, added_cond_kwargs ) # 2. 
Blocks for block in self.transformer_blocks: if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, **ckpt_kwargs, ) else: hidden_states = block( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, ) # 3. Output if self.is_input_continuous: output = self._get_output_for_continuous_inputs( hidden_states=hidden_states, residual=residual, batch_size=batch_size, height=height, width=width, inner_dim=inner_dim, ) elif self.is_input_vectorized: output = self._get_output_for_vectorized_inputs(hidden_states) elif self.is_input_patches: output = self._get_output_for_patched_inputs( hidden_states=hidden_states, timestep=timestep, class_labels=class_labels, embedded_timestep=embedded_timestep, height=height, width=width, ) if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) def _operate_on_continuous_inputs(self, hidden_states): batch, _, height, width = hidden_states.shape hidden_states = self.norm(hidden_states) if not self.use_linear_projection: hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) hidden_states = 
self.proj_in(hidden_states) return hidden_states, inner_dim def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs): batch_size = hidden_states.shape[0] hidden_states = self.pos_embed(hidden_states) embedded_timestep = None if self.adaln_single is not None: if self.use_additional_conditions and added_cond_kwargs is None: raise ValueError( "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." ) timestep, embedded_timestep = self.adaln_single( timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype ) if self.caption_projection is not None: encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) return hidden_states, encoder_hidden_states, timestep, embedded_timestep def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim): if not self.use_linear_projection: hidden_states = ( hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() ) hidden_states = self.proj_out(hidden_states) else: hidden_states = self.proj_out(hidden_states) hidden_states = ( hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() ) output = hidden_states + residual return output def _get_output_for_vectorized_inputs(self, hidden_states): hidden_states = self.norm_out(hidden_states) logits = self.out(hidden_states) # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) logits = logits.permute(0, 2, 1) # log(p(x_0)) output = F.log_softmax(logits.double(), dim=1).float() return output def _get_output_for_patched_inputs( self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None ): if self.config.norm_type != "ada_norm_single": conditioning = self.transformer_blocks[0].norm1.emb( timestep, class_labels, 
hidden_dtype=hidden_states.dtype ) shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] hidden_states = self.proj_out_2(hidden_states) elif self.config.norm_type == "ada_norm_single": shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) # Modulation hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.squeeze(1) # unpatchify if self.adaln_single is None: height = width = int(hidden_states.shape[1] ** 0.5) hidden_states = hidden_states.reshape( shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) ) hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) output = hidden_states.reshape( shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) return output ================================================ FILE: src/diffusers/models/transformers/transformer_flux.py ================================================ # Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Any, Dict, List, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...models.attention import FeedForward from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0 from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings from ..modeling_outputs import Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name # YiYi to-do: refactor rope related functions/classes def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0, "The dimension must be even." 
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim omega = 1.0 / (theta**scale) batch_size, seq_length = pos.shape out = torch.einsum("...n,d->...nd", pos, omega) cos_out = torch.cos(out) sin_out = torch.sin(out) stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) return out.float() # YiYi to-do: refactor rope related functions/classes class EmbedND(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: List[int]): super().__init__() self.dim = dim self.theta = theta self.axes_dim = axes_dim def forward(self, ids: torch.Tensor) -> torch.Tensor: n_axes = ids.shape[-1] emb = torch.cat( [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3, ) return emb.unsqueeze(1) @maybe_allow_in_graph class FluxSingleTransformerBlock(nn.Module): r""" A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. Reference: https://arxiv.org/abs/2403.03206 Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the processing of `context` conditions. 
""" def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): super().__init__() self.mlp_hidden_dim = int(dim * mlp_ratio) self.norm = AdaLayerNormZeroSingle(dim) self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) self.act_mlp = nn.GELU(approximate="tanh") self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) processor = FluxSingleAttnProcessor2_0() self.attn = Attention( query_dim=dim, cross_attention_dim=None, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, bias=True, processor=processor, qk_norm="rms_norm", eps=1e-6, pre_only=True, ) def forward( self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, image_rotary_emb=None, step_index=None, index_block=None, single_copy_blocks=None, single_skip_blocks=None, ): if single_skip_blocks is None or index_block not in single_skip_blocks: residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) attn_output = self.attn( hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, step_index=step_index, index_block=index_block, single_copy_blocks=single_copy_blocks, ) hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) gate = gate.unsqueeze(1) hidden_states = gate * self.proj_out(hidden_states) hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) return hidden_states @maybe_allow_in_graph class FluxTransformerBlock(nn.Module): r""" A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. Reference: https://arxiv.org/abs/2403.03206 Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. 
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the processing of `context` conditions. """ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6): super().__init__() self.norm1 = AdaLayerNormZero(dim) self.norm1_context = AdaLayerNormZero(dim) if hasattr(F, "scaled_dot_product_attention"): processor = FluxAttnProcessor2_0() else: raise ValueError( "The current PyTorch version does not support the `scaled_dot_product_attention` function." ) self.attn = Attention( query_dim=dim, cross_attention_dim=None, added_kv_proj_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, context_pre_only=False, bias=True, processor=processor, qk_norm=qk_norm, eps=eps, ) self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") # let chunk size default to None self._chunk_size = None self._chunk_dim = 0 def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, image_rotary_emb=None, step_index=None, index_block=None, mm_copy_blocks=None, mm_skip_blocks=None, ): if mm_skip_blocks is None or index_block not in mm_skip_blocks: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) # Attention. attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, step_index=step_index, index_block=index_block, mm_copy_blocks=mm_copy_blocks, ) # Process attention outputs for the `hidden_states`. 
attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = hidden_states + attn_output norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] ff_output = self.ff(norm_hidden_states) ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = hidden_states + ff_output # Process attention outputs for the `encoder_hidden_states`. context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output encoder_hidden_states = encoder_hidden_states + context_attn_output norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] context_ff_output = self.ff_context(norm_encoder_hidden_states) encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output if encoder_hidden_states.dtype == torch.float16: encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) return encoder_hidden_states, hidden_states class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): """ The Transformer model introduced in Flux. Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ Parameters: patch_size (`int`): Patch size to turn the input data into small patches. in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, patch_size: int = 1, in_channels: int = 64, num_layers: int = 19, num_single_layers: int = 38, attention_head_dim: int = 128, num_attention_heads: int = 24, joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, axes_dims_rope: List[int] = [16, 56, 56], ): super().__init__() self.out_channels = in_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope) text_time_guidance_cls = ( CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings ) self.time_text_embed = text_time_guidance_cls( embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim ) self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) self.transformer_blocks = nn.ModuleList( [ FluxTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, ) for i in range(self.config.num_layers) ] ) self.single_transformer_blocks = nn.ModuleList( [ FluxSingleTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, ) for i in range(self.config.num_single_layers) ] ) self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) self.gradient_checkpointing = False def _set_gradient_checkpointing(self, module, value=False): if 
hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, pooled_projections: torch.Tensor = None, timestep: torch.LongTensor = None, img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, step_index: int = None, mm_copy_blocks=None, single_copy_blocks=None, mm_skip_blocks=None, single_skip_blocks=None, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ The [`FluxTransformer2DModel`] forward method. Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input `hidden_states`. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected from the embeddings of input conditions. timestep ( `torch.LongTensor`): Used to indicate denoising step. block_controlnet_hidden_states: (`list` of `torch.Tensor`): A list of tensors that if specified are added to the residuals of transformer blocks. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. 
""" if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) * 1000 if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 else: guidance = None temb = ( self.time_text_embed(timestep, pooled_projections) if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) ids = torch.cat((txt_ids, img_ids), dim=1) image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states, temb, image_rotary_emb, **ckpt_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, step_index=step_index, index_block=index_block, mm_copy_blocks=mm_copy_blocks, mm_skip_blocks=mm_skip_blocks, ) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in 
enumerate(self.single_transformer_blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, temb, image_rotary_emb, **ckpt_kwargs, ) else: hidden_states = block( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, step_index=step_index, index_block=index_block, single_copy_blocks=single_copy_blocks, single_skip_blocks=single_skip_blocks, ) hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) ================================================ FILE: src/diffusers/models/transformers/transformer_sd3.py ================================================ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Any, Dict, List, Optional, Union import torch import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...models.attention import JointTransformerBlock from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): """ The Transformer model introduced in Stable Diffusion 3. Reference: https://arxiv.org/abs/2403.03206 Parameters: sample_size (`int`): The width of the latent images. This is fixed during training since it is used to learn a number of position embeddings. patch_size (`int`): Patch size to turn the input data into small patches. in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use. attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`. pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. out_channels (`int`, defaults to 16): Number of output channels. 
""" _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: int = 128, patch_size: int = 2, in_channels: int = 16, num_layers: int = 18, attention_head_dim: int = 64, num_attention_heads: int = 18, joint_attention_dim: int = 4096, caption_projection_dim: int = 1152, pooled_projection_dim: int = 2048, out_channels: int = 16, pos_embed_max_size: int = 96, ): super().__init__() default_out_channels = in_channels self.out_channels = out_channels if out_channels is not None else default_out_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.pos_embed = PatchEmbed( height=self.config.sample_size, width=self.config.sample_size, patch_size=self.config.patch_size, in_channels=self.config.in_channels, embed_dim=self.inner_dim, pos_embed_max_size=pos_embed_max_size, # hard-code for now. ) self.time_text_embed = CombinedTimestepTextProjEmbeddings( embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim ) self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim) # `attention_head_dim` is doubled to account for the mixing. # It needs to crafted when we get the actual checkpoints. 
self.transformer_blocks = nn.ModuleList( [ JointTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, context_pre_only=i == num_layers - 1, ) for i in range(self.config.num_layers) ] ) self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) self.gradient_checkpointing = False # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ Sets the attention processor to use [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). Parameters: chunk_size (`int`, *optional*): The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually over each tensor of dim=`dim`. dim (`int`, *optional*, defaults to `0`): The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) or dim=1 (sequence length). 
""" if dim not in [0, 1]: raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") # By default chunk size is 1 chunk_size = chunk_size or 1 def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) for child in module.children(): fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): fn_recursive_feed_forward(module, chunk_size, dim) # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking def disable_forward_chunking(self): def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) for child in module.children(): fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): fn_recursive_feed_forward(module, None, 0) @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. 
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            Either a single processor instance, which is handed to every sub-module exposing `set_processor`, or
            a dictionary mapping `"<module path>.processor"` to the processor for that module. The dictionary
            must contain exactly as many entries as `self.attn_processors`.

    Raises:
        ValueError: If a dictionary is passed whose length differs from the number of attention processors.
        KeyError: If a dictionary is passed that lacks the key for one of the visited modules.

    The passed dictionary is only read, never modified, so the same mapping (for example one saved earlier from
    `attn_processors`) can be applied repeatedly.
    """
    count = len(self.attn_processors)
    per_layer = isinstance(processor, dict)

    if per_layer and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module):
        if hasattr(module, "set_processor"):
            # Index instead of `pop` so the caller's mapping is left intact.
            module.set_processor(processor[f"{name}.processor"] if per_layer else processor)

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module)
def forward(
    self,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: torch.FloatTensor = None,
    pooled_projections: torch.FloatTensor = None,
    timestep: torch.LongTensor = None,
    block_controlnet_hidden_states: List = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
    """
    The [`SD3Transformer2DModel`] forward method.

    Args:
        hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
            Input `hidden_states`.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
            Embeddings projected from the embeddings of input conditions.
        timestep (`torch.LongTensor`):
            Used to indicate denoising step.
        block_controlnet_hidden_states (`list` of `torch.Tensor`):
            A list of tensors that, if specified, are added to the hidden states after transformer blocks whose
            `context_pre_only` is `False`. Each entry is shared by a run of
            `len(self.transformer_blocks) // len(block_controlnet_hidden_states)` consecutive blocks.
        joint_attention_kwargs (`dict`, *optional*):
            Only the `"scale"` entry (LoRA scale, default `1.0`) is consumed by this method; the dictionary is
            copied first, so the caller's object is not modified.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    scale_requested = False
    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        # Record presence *before* popping; checking afterwards can never see the key.
        scale_requested = joint_attention_kwargs.get("scale", None) is not None
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    elif scale_requested:
        logger.warning(
            "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
        )

    # Spatial size is read before patch embedding so the output can be un-patchified to it.
    height, width = hidden_states.shape[-2:]

    hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
    temb = self.time_text_embed(timestep, pooled_projections)
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    for index_block, block in enumerate(self.transformer_blocks):
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                encoder_hidden_states,
                temb,
                **ckpt_kwargs,
            )
        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
            )

        # controlnet residual
        if block_controlnet_hidden_states is not None and block.context_pre_only is False:
            interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
            hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]

    hidden_states = self.norm_out(hidden_states, temb)
    hidden_states = self.proj_out(hidden_states)

    # unpatchify: (N, h*w, p*p*C) -> (N, C, h*p, w*p)
    patch_size = self.config.patch_size
    height = height // patch_size
    width = width // patch_size

    hidden_states = hidden_states.reshape(
        shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
    )
    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
    output = hidden_states.reshape(
        shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
    )

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
""" sample: torch.Tensor class TransformerTemporalModel(ModelMixin, ConfigMixin): """ A Transformer model for video-like data. Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): The number of channels in the input and output (specify if the input is **continuous**). num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. attention_bias (`bool`, *optional*): Configure if the `TransformerBlock` attention should contain a bias parameter. sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). This is fixed during training since it is used to learn a number of position embeddings. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported activation functions. norm_elementwise_affine (`bool`, *optional*): Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. double_self_attention (`bool`, *optional*): Configure if each `TransformerBlock` should contain two self-attention layers. positional_embeddings: (`str`, *optional*): The type of positional embeddings to apply to the sequence input before passing use. num_positional_embeddings: (`int`, *optional*): The maximum length of the sequence over which to apply positional embeddings. 
""" @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, activation_fn: str = "geglu", norm_elementwise_affine: bool = True, double_self_attention: bool = True, positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, ): super().__init__() self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) self.proj_in = nn.Linear(in_channels, inner_dim) # 3. Define transformers blocks self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, attention_bias=attention_bias, double_self_attention=double_self_attention, norm_elementwise_affine=norm_elementwise_affine, positional_embeddings=positional_embeddings, num_positional_embeddings=num_positional_embeddings, ) for d in range(num_layers) ] ) self.proj_out = nn.Linear(inner_dim, in_channels) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.LongTensor] = None, timestep: Optional[torch.LongTensor] = None, class_labels: torch.LongTensor = None, num_frames: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> TransformerTemporalModelOutput: """ The [`TransformerTemporal`] forward method. 
Args: hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): Input hidden_states. encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.LongTensor`, *optional*): Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in `AdaLayerZeroNorm`. num_frames (`int`, *optional*, defaults to 1): The number of frames to be processed per batch. This is used to reshape the hidden states. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain tuple. Returns: [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: If `return_dict` is True, an [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ # 1. 
class TransformerSpatioTemporalModel(nn.Module):
    """
    A Transformer model for video-like data that pairs every spatial transformer block with a temporal one.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*, defaults to 320):
            The number of channels in the input; also used as the width of the output projection.
        out_channels (`int`, *optional*):
            Stored as `self.out_channels` (falls back to `in_channels`); the output projection does not use it.
        num_layers (`int`, *optional*, defaults to 1): The number of spatial/temporal block pairs.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: int = 320,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        cross_attention_dim: Optional[int] = None,
    ):
        super().__init__()
        width = num_attention_heads * attention_head_dim

        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.inner_dim = width
        self.in_channels = in_channels

        # Input layers
        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
        self.proj_in = nn.Linear(in_channels, width)

        # Spatial blocks
        spatial_blocks = []
        for _ in range(num_layers):
            spatial_blocks.append(
                BasicTransformerBlock(
                    width,
                    num_attention_heads,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                )
            )
        self.transformer_blocks = nn.ModuleList(spatial_blocks)

        # Temporal blocks; the time-mixing width equals the spatial width.
        temporal_blocks = []
        for _ in range(num_layers):
            temporal_blocks.append(
                TemporalBasicTransformerBlock(
                    width,
                    width,
                    num_attention_heads,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                )
            )
        self.temporal_transformer_blocks = nn.ModuleList(temporal_blocks)

        # Frame-index embedding
        self.time_pos_embed = TimestepEmbedding(in_channels, in_channels * 4, out_dim=in_channels)
        self.time_proj = Timesteps(in_channels, True, 0)
        self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")

        # Output layers
        if out_channels is None:
            out_channels = in_channels
        self.out_channels = out_channels
        # TODO: should use out_channels for continuous projections
        self.proj_out = nn.Linear(width, in_channels)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch size * num_frames, channel, height, width)`):
                Input hidden_states.
            encoder_hidden_states (`torch.Tensor`, *optional* in the signature but indexed unconditionally):
                Conditional embeddings for the cross attention layers. The entry of the first frame of each
                batch item is also used as context for the temporal blocks.
            image_only_indicator (`torch.Tensor` of shape `(batch size, num_frames)`, *optional* in the
                signature but required in practice):
                Its last dimension defines the number of frames; it is also forwarded to the time mixer.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a
                [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
                If `return_dict` is True, an
                [`~models.transformers.transformer_temporal.TransformerTemporalModelOutput`] is returned,
                otherwise a `tuple` where the first element is the sample tensor.
        """
        # 1. Input
        batch_frames, _, height, width = hidden_states.shape
        num_frames = image_only_indicator.shape[-1]
        batch_size = batch_frames // num_frames
        tokens = height * width

        # Temporal context: take frame 0 of every batch item and repeat it for each spatial location.
        grouped = encoder_hidden_states[None, :]
        context_dim = encoder_hidden_states.shape[-1]
        first_frame_context = grouped.reshape(batch_size, num_frames, -1, context_dim)[:, 0]
        context_len = encoder_hidden_states.shape[-2]
        time_context = first_frame_context[:, None].broadcast_to(batch_size, tokens, context_len, context_dim)
        time_context = time_context.reshape(batch_size * tokens, -1, time_context.shape[-1])

        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        channels = hidden_states.shape[1]
        # (B*F, C, H, W) -> (B*F, H*W, C)
        hidden_states = hidden_states.permute(0, 2, 3, 1)
        hidden_states = hidden_states.reshape(batch_frames, tokens, channels)
        hidden_states = self.proj_in(hidden_states)

        # Frame indices 0..num_frames-1, tiled once per batch item.
        frame_ids = torch.arange(num_frames, device=hidden_states.device)
        frame_ids = frame_ids.repeat(batch_size, 1).reshape(-1)
        t_emb = self.time_proj(frame_ids)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)

        emb = self.time_pos_embed(t_emb)[:, None, :]

        # 2. Blocks
        use_checkpointing = self.training and self.gradient_checkpointing
        for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
            if use_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    None,
                    encoder_hidden_states,
                    None,
                    use_reentrant=False,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                )

            temporal_out = temporal_block(
                hidden_states + emb,
                num_frames=num_frames,
                encoder_hidden_states=time_context,
            )
            hidden_states = self.time_mixer(
                x_spatial=hidden_states,
                x_temporal=temporal_out,
                image_only_indicator=image_only_indicator,
            )

        # 3. Output
        hidden_states = self.proj_out(hidden_states)
        # (B*F, H*W, C) -> (B*F, C, H, W)
        hidden_states = hidden_states.reshape(batch_frames, height, width, channels)
        hidden_states = hidden_states.permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        if return_dict:
            return TransformerTemporalModelOutput(sample=output)
        return (output,)
class UNet1DModel(ModelMixin, ConfigMixin):
    r"""
    A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int`, *optional*, defaults to 65536): Default length of sample. Should be adaptable at runtime.
        sample_rate (`int`, *optional*): Accepted by the constructor; not read by the code of this class.
        in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
        extra_in_channels (`int`, *optional*, defaults to 0):
            Number of additional channels to be added to the input of the first down block. Useful for cases where the
            input data has more channels than what the model was initially designed for.
        time_embedding_type (`str`, *optional*, defaults to `"fourier"`):
            Type of time embedding to use, either `"fourier"` or `"positional"`.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos in the time projection.
        use_timestep_embedding (`bool`, *optional*, defaults to `False`):
            Whether to pass the projected timestep through a `TimestepEmbedding` MLP. When `False`, the projection
            is repeated along the sample length instead.
        freq_shift (`float`, *optional*, defaults to 0.0):
            Frequency shift passed to the `"positional"` time projection.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
            Tuple of upsample block types.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
        out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
        norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
        layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
        downsample_each_block (`bool`, *optional*, defaults to `False`):
            Experimental feature for using a UNet without upsampling.

    Raises:
        ValueError: If `time_embedding_type` is neither `"fourier"` nor `"positional"`.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 65536,
        sample_rate: Optional[int] = None,
        in_channels: int = 2,
        out_channels: int = 2,
        extra_in_channels: int = 0,
        time_embedding_type: str = "fourier",
        flip_sin_to_cos: bool = True,
        use_timestep_embedding: bool = False,
        freq_shift: float = 0.0,
        down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
        up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
        mid_block_type: str = "UNetMidBlock1D",
        out_block_type: str = None,
        block_out_channels: Tuple[int] = (32, 32, 64),
        act_fn: str = None,
        norm_num_groups: int = 8,
        layers_per_block: int = 1,
        downsample_each_block: bool = False,
    ):
        super().__init__()
        self.sample_size = sample_size

        # time
        if time_embedding_type == "fourier":
            # NOTE(review): the projection is built with `embedding_size=8` while the MLP input width below is
            # derived from `block_out_channels[0]` — confirm these agree for the configs in use.
            self.time_proj = GaussianFourierProjection(
                embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(
                block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
            )
            timestep_input_dim = block_out_channels[0]
        else:
            # Fail fast: without this, `self.time_proj` is never created and `forward` cannot run.
            raise ValueError(
                f"Unsupported `time_embedding_type`: {time_embedding_type}. Must be 'fourier' or 'positional'."
            )

        if use_timestep_embedding:
            time_embed_dim = block_out_channels[0] * 4
            self.time_mlp = TimestepEmbedding(
                in_channels=timestep_input_dim,
                time_embed_dim=time_embed_dim,
                act_fn=act_fn,
                out_dim=block_out_channels[0],
            )

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])
        self.out_block = None

        # down
        output_channel = in_channels
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            if i == 0:
                input_channel += extra_in_channels

            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=block_out_channels[0],
                add_downsample=not is_final_block or downsample_each_block,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
            out_channels=block_out_channels[-1],
            embed_dim=block_out_channels[0],
            num_layers=layers_per_block,
            add_downsample=downsample_each_block,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        # Without an output block the last up block must itself produce `out_channels`.
        if out_block_type is None:
            final_upsample_channels = out_channels
        else:
            final_upsample_channels = block_out_channels[0]

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = (
                reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
            )

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                temb_channels=block_out_channels[0],
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(up_block)

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.out_block = get_out_block(
            out_block_type=out_block_type,
            num_groups_out=num_groups_out,
            embed_dim=block_out_channels[0],
            out_channels=out_channels,
            act_fn=act_fn,
            fc_dim=block_out_channels[-1] // 4,
        )

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        return_dict: bool = True,
    ) -> Union[UNet1DOutput, Tuple]:
        r"""
        The [`UNet1DModel`] forward method.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
            timestep (`torch.Tensor` or `float` or `int`):
                The timestep to condition on. A Python number is converted to a one-element `torch.long` tensor
                on the device of `sample`; a 0-dim tensor is given a leading dimension.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_1d.UNet1DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unets.unet_1d.UNet1DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unets.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.
        """

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        timestep_embed = self.time_proj(timesteps)
        if self.config.use_timestep_embedding:
            timestep_embed = self.time_mlp(timestep_embed)
        else:
            # Tile the embedding along the sample length, then expand it to the batch size of `sample`.
            timestep_embed = timestep_embed[..., None]
            timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
            timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))

        # 2. down
        down_block_res_samples = ()
        for downsample_block in self.down_blocks:
            sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
            down_block_res_samples += res_samples

        # 3. mid
        if self.mid_block:
            sample = self.mid_block(sample, timestep_embed)

        # 4. up: each up block consumes the most recent remaining skip connection.
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-1:]
            down_block_res_samples = down_block_res_samples[:-1]
            sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)

        # 5. post-process
        if self.out_block:
            sample = self.out_block(sample, timestep_embed)

        if not return_dict:
            return (sample,)

        return UNet1DOutput(sample=sample)
class DownResnetBlock1D(nn.Module):
    """
    Stack of `num_layers + 1` `ResidualTemporalBlock1D` layers followed by an optional activation and an
    optional `Downsample1D`.

    `forward` returns both the (possibly downsampled) hidden states and a one-element tuple holding the hidden
    states as they were right after the residual layers.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        conv_shortcut: bool = False,
        temb_channels: int = 32,
        groups: int = 32,
        groups_out: Optional[int] = None,
        non_linearity: Optional[str] = None,
        time_embedding_norm: str = "default",
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
    ):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.add_downsample = add_downsample
        self.output_scale_factor = output_scale_factor

        # One layer mapping in -> out, then `num_layers` layers mapping out -> out.
        layers = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
        layers += [
            ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels) for _ in range(num_layers)
        ]
        self.resnets = nn.ModuleList(layers)

        self.nonlinearity = get_activation(non_linearity) if non_linearity is not None else None
        self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) if add_downsample else None

    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        for layer in self.resnets:
            hidden_states = layer(hidden_states, temb)

        # The skip output is captured before the activation and the downsampling are applied.
        skip_states = (hidden_states,)

        if self.nonlinearity is not None:
            hidden_states = self.nonlinearity(hidden_states)

        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states)

        return hidden_states, skip_states
class MidResTemporalBlock1D(nn.Module):
    """
    Middle block made of [`ResidualTemporalBlock1D`] layers, optionally followed by an up- or downsampler.

    Parameters:
        in_channels (`int`): Channel count fed to the first residual block.
        out_channels (`int`): Channel count produced by every residual block.
        embed_dim (`int`): Passed to each residual block as `embed_dim`.
        num_layers (`int`, defaults to 1): Number of residual blocks created *in addition to* the first one.
        add_downsample (`bool`, defaults to `False`): Whether to append a `Downsample1D` layer.
        add_upsample (`bool`, defaults to `False`): Whether to append an `Upsample1D` layer.
        non_linearity (`str`, *optional*): Name handed to `get_activation`. The resulting module is stored as
            `self.nonlinearity` but is not applied in `forward`.

    Raises:
        ValueError: If both `add_downsample` and `add_upsample` are requested.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        embed_dim: int,
        num_layers: int = 1,
        add_downsample: bool = False,
        add_upsample: bool = False,
        non_linearity: Optional[str] = None,
    ):
        super().__init__()

        # Reject the contradictory configuration before any layer is allocated.
        if add_upsample and add_downsample:
            raise ValueError("Block cannot downsample and upsample")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_downsample = add_downsample
        self.add_upsample = add_upsample

        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
        for _ in range(num_layers):
            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
        self.resnets = nn.ModuleList(resnets)

        if non_linearity is None:
            self.nonlinearity = None
        else:
            self.nonlinearity = get_activation(non_linearity)

        self.upsample = None
        if add_upsample:
            self.upsample = Upsample1D(out_channels, use_conv=True)

        self.downsample = None
        if add_downsample:
            self.downsample = Downsample1D(out_channels, use_conv=True)

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        """
        Apply every residual block (each receives `temb`), then the configured resampler, and return the result.
        """
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)

        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)
        if self.downsample is not None:
            # The result must be written to `hidden_states`. Assigning it to `self.downsample` would try to replace
            # the registered submodule with a tensor and leave the returned activations un-downsampled.
            hidden_states = self.downsample(hidden_states)

        return hidden_states
class SelfAttention1d(nn.Module):
    """
    Multi-head self-attention over the last axis of a `(batch, channels, seq)` tensor, with a residual connection.

    Parameters:
        in_channels (`int`): Channel count of the input; also the width of every linear projection.
        n_head (`int`, defaults to 1): Number of attention heads the channels are split into.
        dropout_rate (`float`, defaults to 0.0): Probability for the in-place dropout applied before the residual add.
    """

    def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
        super().__init__()
        self.channels = in_channels
        self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
        self.num_heads = n_head

        width = self.channels
        self.query = nn.Linear(width, width)
        self.key = nn.Linear(width, width)
        self.value = nn.Linear(width, width)
        self.proj_attn = nn.Linear(width, width, bias=True)
        self.dropout = nn.Dropout(dropout_rate, inplace=True)

    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
        """Split the last axis into heads: (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)."""
        leading_dims = projection.size()[:-1]
        per_head = projection.view(leading_dims + (self.num_heads, -1))
        return per_head.permute(0, 2, 1, 3)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `hidden_states` plus its attended, projected and dropped-out counterpart."""
        skip = hidden_states
        # Unpacking keeps the requirement that the input is exactly three-dimensional.
        _batch, _channels, _seq = hidden_states.shape

        # Normalise, then move channels last so the linear layers act per sequence position.
        tokens = self.group_norm(hidden_states).transpose(1, 2)

        q = self.transpose_for_scores(self.query(tokens))
        k = self.transpose_for_scores(self.key(tokens))
        v = self.transpose_for_scores(self.value(tokens))

        # The scaling is split evenly between queries and keys (fourth root of the head size on each side).
        head_dim = k.shape[-1]
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        scores = torch.matmul(q * scale, k.transpose(-1, -2) * scale)
        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values, then merge the heads back: (B, H, T, D) -> (B, T, H * D).
        attended = torch.matmul(weights, v).permute(0, 2, 1, 3).contiguous()
        attended = attended.view(attended.size()[:-2] + (self.channels,))

        projected = self.proj_attn(attended).transpose(1, 2)
        projected = self.dropout(projected)

        return projected + skip
class UNetMidBlock1D(nn.Module):
    """
    Middle block: cubic downsample, six (`ResConvBlock`, `SelfAttention1d`) pairs, cubic upsample.

    Parameters:
        mid_channels (`int`): Width used inside the block; each attention layer gets `channels // 32` heads.
        in_channels (`int`): Channel count of the incoming tensor.
        out_channels (`int`, *optional*): Channel count produced by the last pair. Falls back to `in_channels`.
    """

    def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels

        # Channel count at the boundary of each of the six stages: in -> mid (x5) -> out.
        boundaries = [in_channels] + [mid_channels] * 5 + [out_channels]

        self.down = Downsample1d("cubic")
        # All residual blocks are created before any attention layer so parameter initialisation order is unchanged.
        res_stack = [
            ResConvBlock(stage_in, mid_channels, stage_out)
            for stage_in, stage_out in zip(boundaries[:-1], boundaries[1:])
        ]
        attn_stack = [SelfAttention1d(stage_out, stage_out // 32) for stage_out in boundaries[1:]]
        self.up = Upsample1d(kernel="cubic")

        self.attentions = nn.ModuleList(attn_stack)
        self.resnets = nn.ModuleList(res_stack)

    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Downsample, run each residual block followed by its attention layer, then upsample. `temb` is ignored."""
        hidden_states = self.down(hidden_states)
        for attention, res_block in zip(self.attentions, self.resnets):
            hidden_states = attention(res_block(hidden_states))
        return self.up(hidden_states)
class AttnUpBlock1D(nn.Module):
    """
    Up block: concatenate the skip tensor, run three (`ResConvBlock`, `SelfAttention1d`) pairs, cubic upsample.

    Parameters:
        in_channels (`int`): Channel count of the incoming hidden states; the first residual block expects
            `2 * in_channels` because the skip tensor is concatenated on the channel axis.
        out_channels (`int`): Channel count produced by the last pair.
        mid_channels (`int`, *optional*): Width used inside the block. Falls back to `out_channels`.
    """

    def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels

        # (input, output) channels of the three stages; the hidden width of every ResConvBlock is `mid_channels`.
        stage_io = [
            (2 * in_channels, mid_channels),
            (mid_channels, mid_channels),
            (mid_channels, out_channels),
        ]
        # Residual blocks are all built before the attention layers so parameter initialisation order is unchanged.
        res_stack = [ResConvBlock(stage_in, mid_channels, stage_out) for stage_in, stage_out in stage_io]
        attn_stack = [SelfAttention1d(stage_out, stage_out // 32) for _, stage_out in stage_io]

        self.attentions = nn.ModuleList(attn_stack)
        self.resnets = nn.ModuleList(res_stack)
        self.up = Upsample1d(kernel="cubic")

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Concatenate the last entry of `res_hidden_states_tuple` onto `hidden_states` along dim 1, apply each residual
        block followed by its attention layer, and upsample. `temb` is ignored.
        """
        skip = res_hidden_states_tuple[-1]
        merged = torch.cat([hidden_states, skip], dim=1)

        for res_block, attention in zip(self.resnets, self.attentions):
            merged = attention(res_block(merged))

        return self.up(merged)
def get_down_block(
    down_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    add_downsample: bool,
) -> DownBlockType:
    """
    Instantiate the 1D down block named by `down_block_type`.

    `num_layers`, `temb_channels` and `add_downsample` are only forwarded to `DownResnetBlock1D`; the other block
    types receive just `in_channels` and `out_channels`.

    Raises:
        ValueError: If `down_block_type` is not one of the four supported names.
    """
    # The residual-temporal variant is the only one that takes the full argument set.
    if down_block_type == "DownResnetBlock1D":
        return DownResnetBlock1D(
            in_channels=in_channels,
            num_layers=num_layers,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
        )

    # The remaining variants share one constructor shape, so resolve them from a small table.
    channel_only_blocks = (
        ("DownBlock1D", DownBlock1D),
        ("AttnDownBlock1D", AttnDownBlock1D),
        ("DownBlock1DNoSkip", DownBlock1DNoSkip),
    )
    for block_name, block_cls in channel_only_blocks:
        if down_block_type == block_name:
            return block_cls(out_channels=out_channels, in_channels=in_channels)

    raise ValueError(f"{down_block_type} does not exist.")
class UNet2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    The middle block is always a [`UNetMidBlock2D`]; it is not configurable through the constructor.

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
            1)`.
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        time_embedding_type (`str`, *optional*, defaults to `"positional"`):
            Type of time embedding to use. One of `"positional"`, `"fourier"` or `"learned"`.
        freq_shift (`int`, *optional*, defaults to 0):
            Frequency shift passed to the `Timesteps` projection used for `"positional"` time embeddings.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos; passed to the `Timesteps` projection used for `"positional"` time embeddings.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
        mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
        downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
        downsample_type (`str`, *optional*, defaults to `conv`):
            The downsample type for downsampling layers. Choose between "conv" and "resnet"
        upsample_type (`str`, *optional*, defaults to `conv`):
            The upsample type for upsampling layers. Choose between "conv" and "resnet"
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        attention_head_dim (`int`, *optional*, defaults to `8`):
            The attention head dimension. When `None`, each block falls back to its own output channel count.
        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
        attn_norm_num_groups (`int`, *optional*, defaults to `None`):
            If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
            given number of groups. If left as `None`, the group norm layer will only be created if
            `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        add_attention (`bool`, *optional*, defaults to `True`):
            Forwarded to the mid block as its `add_attention` argument.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, or `"identity"`.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
            conditioning with `class_embed_type` equal to `None`.
        num_train_timesteps (`int`, *optional*, defaults to `None`):
            Number of rows of the `nn.Embedding` table created when `time_embedding_type` is `"learned"`. Required in
            that case and unused otherwise.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 3,
        out_channels: int = 3,
        center_input_sample: bool = False,
        time_embedding_type: str = "positional",
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
        up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        dropout: float = 0.0,
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 32,
        attn_norm_num_groups: Optional[int] = None,
        norm_eps: float = 1e-5,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        num_train_timesteps: Optional[int] = None,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        # input
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        elif time_embedding_type == "learned":
            if num_train_timesteps is None:
                raise ValueError("`num_train_timesteps` must be provided when `time_embedding_type` is 'learned'.")
            self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
            timestep_input_dim = block_out_channels[0]
        else:
            # Without this branch `timestep_input_dim` would be unbound below and surface as an opaque NameError.
            raise ValueError(
                f"Unsupported `time_embedding_type`: {time_embedding_type}. Choose from 'positional', 'fourier' or 'learned'."
            )

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        # `mid_block` starts as a plain `None` and is assigned after the down blocks are built; keeping these three
        # statements in this order preserves the submodule registration order (and hence `state_dict` key order).
        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                downsample_padding=downsample_padding,
                resnet_time_scale_shift=resnet_time_scale_shift,
                downsample_type=downsample_type,
                dropout=dropout,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            dropout=dropout,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
            resnet_groups=norm_num_groups,
            attn_groups=attn_norm_num_groups,
            add_attention=add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                resnet_time_scale_shift=resnet_time_scale_shift,
                upsample_type=upsample_type,
                dropout=dropout,
            )
            self.up_blocks.append(up_block)

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DOutput, Tuple]:
        r"""
        The [`UNet2DModel`] forward method.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_2d.UNet2DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unets.unet_2d.UNet2DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unets.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.

        Raises:
            ValueError: If the model has a class embedding but `class_labels` is `None`, or if `class_labels` is given
                while the model has no class embedding.
        """
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb
        elif class_labels is not None:
            raise ValueError("class_embedding needs to be initialized in order to use class conditioning")

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            # Blocks exposing `skip_conv` additionally consume and return a running skip sample.
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.up_blocks:
            # Each up block consumes as many residuals (taken from the end) as it has resnets.
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            # With Fourier embeddings the output is divided by the (broadcast) timestep values.
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        if not return_dict:
            return (sample,)

        return UNet2DOutput(sample=sample)
def get_down_block(
    down_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    add_downsample: bool,
    resnet_eps: float,
    resnet_act_fn: str,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    downsample_padding: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    downsample_type: Optional[str] = None,
    dropout: float = 0.0,
):
    """
    Build the down block named by `down_block_type`.

    The name (with an optional leading `"UNetRes"` removed) selects one of the down-block classes of this module;
    only the keyword arguments that the selected class accepts are forwarded to it.

    Args:
        down_block_type (`str`): Name of the block class, e.g. `"DownBlock2D"` or `"CrossAttnDownBlock2D"`.
        num_layers, in_channels, out_channels, dropout, resnet_eps, resnet_act_fn:
            Forwarded to every block type.
        temb_channels, add_downsample, resnet_groups, downsample_padding, resnet_time_scale_shift, ...:
            Forwarded only to the block types that take them.
        attention_head_dim (`int`, *optional*):
            When `None`, a warning is logged and `num_attention_heads` is used instead.
        downsample_type (`str`, *optional*):
            Only read for `"AttnDownBlock2D"`: forced to `None` when `add_downsample is False`, otherwise falls
            back to `"conv"` when falsy.

    Returns:
        The constructed block instance.

    Raises:
        ValueError: If the block name is unknown, or if `cross_attention_dim` is `None` for
            `"CrossAttnDownBlock2D"` / `"SimpleCrossAttnDownBlock2D"`.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    legacy_prefix = "UNetRes"
    if down_block_type.startswith(legacy_prefix):
        down_block_type = down_block_type[len(legacy_prefix) :]

    # Keyword arguments accepted by every down block type handled below.
    shared = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        dropout=dropout,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
    )

    if down_block_type == "DownBlock2D":
        return DownBlock2D(
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
            **shared,
        )

    if down_block_type == "ResnetDownsampleBlock2D":
        return ResnetDownsampleBlock2D(
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            **shared,
        )

    if down_block_type == "AttnDownBlock2D":
        # No downsampler at all when downsampling is disabled; otherwise default to 'conv'.
        downsample_type = None if add_downsample is False else (downsample_type or "conv")
        return AttnDownBlock2D(
            temb_channels=temb_channels,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            downsample_type=downsample_type,
            **shared,
        )

    if down_block_type == "CrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
        return CrossAttnDownBlock2D(
            transformer_layers_per_block=transformer_layers_per_block,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            **shared,
        )

    if down_block_type == "SimpleCrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
        return SimpleCrossAttnDownBlock2D(
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            only_cross_attention=only_cross_attention,
            cross_attention_norm=cross_attention_norm,
            **shared,
        )

    if down_block_type == "SkipDownBlock2D":
        return SkipDownBlock2D(
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
            **shared,
        )

    if down_block_type == "AttnSkipDownBlock2D":
        return AttnSkipDownBlock2D(
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            **shared,
        )

    if down_block_type == "DownEncoderBlock2D":
        return DownEncoderBlock2D(
            add_downsample=add_downsample,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
            **shared,
        )

    if down_block_type == "AttnDownEncoderBlock2D":
        return AttnDownEncoderBlock2D(
            add_downsample=add_downsample,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            **shared,
        )

    if down_block_type == "KDownBlock2D":
        return KDownBlock2D(
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            **shared,
        )

    if down_block_type == "KCrossAttnDownBlock2D":
        return KCrossAttnDownBlock2D(
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            # self-attention is enabled exactly when the block does not downsample
            add_self_attention=not add_downsample,
            **shared,
        )

    raise ValueError(f"{down_block_type} does not exist.")
def get_mid_block(
    mid_block_type: str,
    temb_channels: int,
    in_channels: int,
    resnet_eps: float,
    resnet_act_fn: str,
    resnet_groups: int,
    output_scale_factor: float = 1.0,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    mid_block_only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    resnet_skip_time_act: bool = False,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = 1,
    dropout: float = 0.0,
):
    """
    Build the mid block named by `mid_block_type`.

    Args:
        mid_block_type (`str` or `None`):
            One of `"UNetMidBlock2DCrossAttn"`, `"UNetMidBlock2DSimpleCrossAttn"`, `"UNetMidBlock2D"`, or `None`
            to request no mid block.
        temb_channels, in_channels, resnet_eps, resnet_act_fn, resnet_groups, output_scale_factor,
        resnet_time_scale_shift, dropout:
            Forwarded to every mid block type.
        Remaining arguments: Forwarded only to the block types that take them.

    Returns:
        The constructed mid block, or `None` when `mid_block_type` is `None`. A plain `"UNetMidBlock2D"` is built
        with `num_layers=0` and `add_attention=False`.

    Raises:
        ValueError: If `mid_block_type` is not one of the supported values.
    """
    if mid_block_type is None:
        return None

    # Keyword arguments accepted by all three mid block types.
    shared = dict(
        in_channels=in_channels,
        temb_channels=temb_channels,
        dropout=dropout,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        output_scale_factor=output_scale_factor,
        resnet_groups=resnet_groups,
        resnet_time_scale_shift=resnet_time_scale_shift,
    )

    if mid_block_type == "UNetMidBlock2DCrossAttn":
        return UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
            attention_type=attention_type,
            **shared,
        )

    if mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
        return UNetMidBlock2DSimpleCrossAttn(
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            skip_time_act=resnet_skip_time_act,
            only_cross_attention=mid_block_only_cross_attention,
            cross_attention_norm=cross_attention_norm,
            **shared,
        )

    if mid_block_type == "UNetMidBlock2D":
        return UNetMidBlock2D(
            num_layers=0,
            add_attention=False,
            **shared,
        )

    raise ValueError(f"unknown mid_block_type : {mid_block_type}")
def get_up_block(
    up_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    prev_output_channel: int,
    temb_channels: int,
    add_upsample: bool,
    resnet_eps: float,
    resnet_act_fn: str,
    resolution_idx: Optional[int] = None,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = None,
    upsample_type: Optional[str] = None,
    dropout: float = 0.0,
) -> nn.Module:
    """
    Build the up block named by `up_block_type`.

    The name (with an optional leading `"UNetRes"` removed) selects one of the up-block classes of this module;
    only the keyword arguments that the selected class accepts are forwarded to it.

    Args:
        up_block_type (`str`): Name of the block class, e.g. `"UpBlock2D"` or `"CrossAttnUpBlock2D"`.
        num_layers, in_channels, out_channels, temb_channels, resolution_idx, dropout, resnet_eps, resnet_act_fn:
            Forwarded to every block type.
        prev_output_channel, add_upsample, resnet_groups, resnet_time_scale_shift, ...:
            Forwarded only to the block types that take them.
        attention_head_dim (`int`, *optional*):
            When `None`, a warning is logged and `num_attention_heads` is used instead.
        upsample_type (`str`, *optional*):
            Only read for `"AttnUpBlock2D"`: forced to `None` when `add_upsample is False`, otherwise falls back
            to `"conv"` when falsy.

    Returns:
        The constructed block instance.

    Raises:
        ValueError: If the block name is unknown, or if `cross_attention_dim` is `None` for
            `"CrossAttnUpBlock2D"` / `"SimpleCrossAttnUpBlock2D"`.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    legacy_prefix = "UNetRes"
    if up_block_type.startswith(legacy_prefix):
        up_block_type = up_block_type[len(legacy_prefix) :]

    # Keyword arguments accepted by every up block type handled below.
    shared = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        temb_channels=temb_channels,
        resolution_idx=resolution_idx,
        dropout=dropout,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
    )

    if up_block_type == "UpBlock2D":
        return UpBlock2D(
            prev_output_channel=prev_output_channel,
            add_upsample=add_upsample,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            **shared,
        )

    if up_block_type == "ResnetUpsampleBlock2D":
        return ResnetUpsampleBlock2D(
            prev_output_channel=prev_output_channel,
            add_upsample=add_upsample,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            **shared,
        )

    if up_block_type == "CrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
        return CrossAttnUpBlock2D(
            transformer_layers_per_block=transformer_layers_per_block,
            prev_output_channel=prev_output_channel,
            add_upsample=add_upsample,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            **shared,
        )

    if up_block_type == "SimpleCrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
        return SimpleCrossAttnUpBlock2D(
            prev_output_channel=prev_output_channel,
            add_upsample=add_upsample,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            only_cross_attention=only_cross_attention,
            cross_attention_norm=cross_attention_norm,
            **shared,
        )

    if up_block_type == "AttnUpBlock2D":
        # No upsampler at all when upsampling is disabled; otherwise default to 'conv'.
        upsample_type = None if add_upsample is False else (upsample_type or "conv")
        return AttnUpBlock2D(
            prev_output_channel=prev_output_channel,
            resnet_groups=resnet_groups,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            upsample_type=upsample_type,
            **shared,
        )

    if up_block_type == "SkipUpBlock2D":
        return SkipUpBlock2D(
            prev_output_channel=prev_output_channel,
            add_upsample=add_upsample,
            resnet_time_scale_shift=resnet_time_scale_shift,
            **shared,
        )

    if up_block_type == "AttnSkipUpBlock2D":
        return AttnSkipUpBlock2D(
            prev_output_channel=prev_output_channel,
            add_upsample=add_upsample,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            **shared,
        )

    if up_block_type == "UpDecoderBlock2D":
        return UpDecoderBlock2D(
            add_upsample=add_upsample,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            **shared,
        )

    if up_block_type == "AttnUpDecoderBlock2D":
        return AttnUpDecoderBlock2D(
            add_upsample=add_upsample,
            resnet_groups=resnet_groups,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            **shared,
        )

    if up_block_type == "KUpBlock2D":
        return KUpBlock2D(
            add_upsample=add_upsample,
            **shared,
        )

    if up_block_type == "KCrossAttnUpBlock2D":
        return KCrossAttnUpBlock2D(
            add_upsample=add_upsample,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            **shared,
        )

    raise ValueError(f"{up_block_type} does not exist.")
class AutoencoderTinyBlock(nn.Module):
    """
    Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
    blocks.

    Args:
        in_channels (`int`): The number of input channels.
        out_channels (`int`): The number of output channels.
        act_fn (`str`):
            The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.

    Returns:
        `torch.Tensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
        `out_channels`.
    """

    def __init__(self, in_channels: int, out_channels: int, act_fn: str):
        super().__init__()
        activation = get_activation(act_fn)

        # Three 3x3 convolutions; the same activation module instance sits between them.
        main_path = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            activation,
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            activation,
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        ]
        self.conv = nn.Sequential(*main_path)

        # The shortcut only needs a 1x1 projection when the channel count changes.
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        else:
            self.skip = nn.Identity()

        self.fuse = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `fuse(conv(x) + skip(x))`."""
        residual = self.skip(x)
        return self.fuse(self.conv(x) + residual)
class UNetMidBlock2D(nn.Module):
    """
    A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.

    The block holds `num_layers + 1` resnets and `num_layers` attention slots; a slot is `None` when
    `add_attention` is `False`.

    Args:
        in_channels (`int`): The number of input channels.
        temb_channels (`int`): The number of temporal embedding channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of (attention, resnet) pairs after the first resnet.
        resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
        resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
            Passed to the resnets as `time_embedding_norm`. The value `"spatial"` switches the resnets to
            `ResnetBlockCondNorm2D` and gives the attention layers a `spatial_norm_dim` of `temb_channels`.
        resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks. `None` selects
            `min(in_channels // 4, 32)`.
        attn_groups (`Optional[int]`, *optional*, defaults to None):
            The number of groups for the attention blocks. When `None`, `resnet_groups` is used if
            `resnet_time_scale_shift == "default"`; otherwise it stays `None`.
        resnet_pre_norm (`bool`, *optional*, defaults to `True`):
            Whether to use pre-normalization for the resnet blocks (not forwarded to the `"spatial"` resnets).
        add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
        attention_head_dim (`int`, *optional*, defaults to 1):
            Dimension of a single attention head. The number of attention heads is
            `in_channels // attention_head_dim`. `None` falls back to `in_channels` with a warning.
        output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

    Returns:
        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
        in_channels, height, width)`.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        attn_groups: Optional[int] = None,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)
        self.add_attention = add_attention

        uses_spatial_norm = resnet_time_scale_shift == "spatial"
        if attn_groups is None and resnet_time_scale_shift == "default":
            attn_groups = resnet_groups

        def make_resnet():
            # Every resnet of this block maps in_channels -> in_channels with the same settings.
            if uses_spatial_norm:
                return ResnetBlockCondNorm2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm="spatial",
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                )
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        # there is always at least one resnet
        resnet_layers = [make_resnet()]
        attention_layers = []

        if attention_head_dim is None:
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
            )
            attention_head_dim = in_channels

        for _ in range(num_layers):
            attention_layers.append(
                Attention(
                    in_channels,
                    heads=in_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=attn_groups,
                    spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
                if self.add_attention
                else None
            )
            resnet_layers.append(make_resnet())

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the first resnet, then each (optional attention, resnet) pair in order."""
        out = self.resnets[0](hidden_states, temb)
        for attention, resnet in zip(self.attentions, self.resnets[1:]):
            # `None` marks a slot built with `add_attention=False`.
            if attention is not None:
                out = attention(out, temb=temb)
            out = resnet(out, temb)
        return out
class UNetMidBlock2DCrossAttn(nn.Module):
    """
    UNet mid-block that alternates resnets with cross-attention transformer blocks.

    The block holds `num_layers + 1` resnets and `num_layers` transformer blocks (`Transformer2DModel`, or
    `DualTransformer2DModel` when `dual_cross_attention` is set).

    Args:
        in_channels (`int`): Channels of the input to the first resnet.
        temb_channels (`int`): Passed to every resnet as `temb_channels`.
        out_channels (`int`, *optional*): Output channels; falls back to `in_channels` when falsy.
        dropout (`float`): Passed to every resnet.
        num_layers (`int`): Number of (transformer, resnet) pairs after the first resnet.
        transformer_layers_per_block (`int` or `Tuple[int]`):
            `num_layers` of each `Transformer2DModel`; an `int` is repeated for every layer. Not used for the dual
            transformer, which is always built with `num_layers=1`.
        resnet_groups (`int`): Group count of the first resnet; `None` selects `min(in_channels // 4, 32)`.
        resnet_groups_out (`int`, *optional*): Group count on the `out_channels` side; falls back to `resnet_groups`.
        num_attention_heads (`int`): Number of heads; each head has `out_channels // num_attention_heads` dims.
        cross_attention_dim, use_linear_projection, upcast_attention, attention_type:
            Forwarded to the transformer blocks.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_groups_out: Optional[int] = None,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
    ):
        super().__init__()

        out_channels = out_channels or in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        # support for variable transformer layers per block
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        resnet_groups_out = resnet_groups_out or resnet_groups
        head_dim = out_channels // num_attention_heads

        # Settings common to the entry resnet and the per-layer resnets.
        resnet_settings = dict(
            temb_channels=temb_channels,
            eps=resnet_eps,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
        )

        # there is always at least one resnet
        resnet_layers = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=out_channels,
                groups=resnet_groups,
                groups_out=resnet_groups_out,
                **resnet_settings,
            )
        ]
        attention_layers = []

        for layer_idx in range(num_layers):
            if dual_cross_attention:
                transformer = DualTransformer2DModel(
                    num_attention_heads,
                    head_dim,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            else:
                transformer = Transformer2DModel(
                    num_attention_heads,
                    head_dim,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[layer_idx],
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups_out,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    attention_type=attention_type,
                )
            attention_layers.append(transformer)
            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    groups=resnet_groups_out,
                    **resnet_settings,
                )
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the first resnet, then every (transformer, resnet) pair in order.

        When the module is training with `gradient_checkpointing` enabled, the resnet of each pair runs through
        `torch.utils.checkpoint.checkpoint`; the transformer call is the same in both modes. A non-`None` `scale`
        entry in `cross_attention_kwargs` only triggers a deprecation warning here.
        """
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        def bind_module(module):
            # Bind `module` now: the checkpointed function may be re-run after the loop has moved on.
            def run(*inputs):
                return module(*inputs)

            return run

        use_checkpointing = self.training and self.gradient_checkpointing

        hidden_states = self.resnets[0](hidden_states, temb)
        for transformer, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = transformer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

            if use_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    bind_module(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

        return hidden_states
class UNetMidBlock2DSimpleCrossAttn(nn.Module):
    """
    UNet mid-block that alternates resnets with added-KV cross-attention layers.

    The block holds `num_layers + 1` resnets and `num_layers` `Attention` layers. Each attention layer gets its
    own processor: `AttnAddedKVProcessor2_0` when `F.scaled_dot_product_attention` exists, otherwise
    `AttnAddedKVProcessor`.

    Args:
        in_channels (`int`): Input and output channels of every resnet, and the attention query dimension.
        temb_channels (`int`): Passed to every resnet as `temb_channels`.
        num_layers (`int`): Number of (attention, resnet) pairs after the first resnet.
        resnet_groups (`int`): Group count for resnets and attention; `None` selects `min(in_channels // 4, 32)`.
        attention_head_dim (`int`): Dimension of one head; the head count is `in_channels // attention_head_dim`.
        cross_attention_dim (`int`): Passed to the attention layers as `added_kv_proj_dim`.
        skip_time_act, only_cross_attention, cross_attention_norm, dropout, resnet_eps, resnet_time_scale_shift,
        resnet_act_fn, resnet_pre_norm, output_scale_factor:
            Forwarded to the resnets / attention layers that take them.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        skip_time_act: bool = False,
        only_cross_attention: bool = False,
        cross_attention_norm: Optional[str] = None,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attention_head_dim = attention_head_dim

        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        self.num_heads = in_channels // self.attention_head_dim

        def make_resnet():
            # Every resnet of this block maps in_channels -> in_channels with the same settings.
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                skip_time_act=skip_time_act,
            )

        # there is always at least one resnet
        resnet_layers = [make_resnet()]
        attention_layers = []

        for _ in range(num_layers):
            if hasattr(F, "scaled_dot_product_attention"):
                processor = AttnAddedKVProcessor2_0()
            else:
                processor = AttnAddedKVProcessor()

            attention_layers.append(
                Attention(
                    query_dim=in_channels,
                    cross_attention_dim=in_channels,
                    heads=self.num_heads,
                    dim_head=self.attention_head_dim,
                    added_kv_proj_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    bias=True,
                    upcast_softmax=True,
                    only_cross_attention=only_cross_attention,
                    cross_attention_norm=cross_attention_norm,
                    processor=processor,
                )
            )
            resnet_layers.append(make_resnet())

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the first resnet, then every (attention, resnet) pair in order.

        The mask handed to the attention layers is `attention_mask` when given; otherwise it is
        `encoder_attention_mask` if `encoder_hidden_states` is given, else `None`. All entries of
        `cross_attention_kwargs` are passed on to the attention layers as keyword arguments.
        """
        if cross_attention_kwargs is None:
            cross_attention_kwargs = {}
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        if attention_mask is not None:
            # when attention_mask is defined: we don't even check for encoder_attention_mask.
            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
            #       then we can simplify this whole if/else block to:
            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
            mask = attention_mask
        elif encoder_hidden_states is not None:
            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
            mask = encoder_attention_mask
        else:
            mask = None

        hidden_states = self.resnets[0](hidden_states, temb)
        for attention, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attention(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=mask,
                **cross_attention_kwargs,
            )
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
class AttnDownBlock2D(nn.Module):
    """
    Down block made of `num_layers` (resnet, self-attention) pairs followed by an optional downsampler.

    Args:
        in_channels (`int`): Input channels of the first resnet; later resnets take `out_channels`.
        out_channels (`int`): Output channels of every resnet and the attention/downsampler channel count.
        temb_channels (`int`): Passed to every resnet as `temb_channels`.
        dropout (`float`): Passed to every resnet.
        num_layers (`int`): Number of (resnet, attention) pairs.
        resnet_eps, resnet_time_scale_shift, resnet_act_fn, resnet_groups, resnet_pre_norm, output_scale_factor:
            Forwarded to the resnets (and, where applicable, to the attention layers).
        attention_head_dim (`int`):
            Dimension of one attention head; the head count is `out_channels // attention_head_dim`. `None` falls
            back to `out_channels` with a warning.
        downsample_padding (`int`): Padding of the `"conv"` downsampler.
        downsample_type (`str`):
            `"conv"` appends a `Downsample2D`, `"resnet"` appends a `ResnetBlock2D` built with `down=True`; any
            other value (including `None`) builds no downsampler.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
    ):
        super().__init__()
        resnets = []
        attentions = []
        self.downsample_type = downsample_type

        if attention_head_dim is None:
            # The fallback value is `out_channels`, so the message must name that variable.
            logger.warning(
                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
            )
            attention_head_dim = out_channels

        for i in range(num_layers):
            # Only the first resnet sees the block's input width.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                Attention(
                    out_channels,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if downsample_type == "conv":
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        elif downsample_type == "resnet":
            self.downsamplers = nn.ModuleList(
                [
                    ResnetBlock2D(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                        down=True,
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        upsample_size: Optional[int] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply every (resnet, attention) pair, then the downsampler(s) if present.

        Args:
            hidden_states: Input passed to the first resnet.
            temb: Passed to each resnet, and to the downsampler when `downsample_type == "resnet"`.
            upsample_size: Accepted but not read by this block.
            cross_attention_kwargs: Passed to each attention layer as keyword arguments. A non-`None` `scale`
                entry triggers a deprecation warning.

        Returns:
            `(hidden_states, output_states)` where `output_states` holds the output of every pair plus, when
            downsamplers exist, the final downsampled output.
        """
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, **cross_attention_kwargs)
            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                # Only the resnet-style downsampler takes the time embedding.
                if self.downsample_type == "resnet":
                    hidden_states = downsampler(hidden_states, temb=temb)
                else:
                    hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
`scale` will be ignored.") output_states = () for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, **cross_attention_kwargs) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: if self.downsample_type == "resnet": hidden_states = downsampler(hidden_states, temb=temb) else: hidden_states = downsampler(hidden_states) output_states += (hidden_states,) return hidden_states, output_states class CrossAttnDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads: int = 1, cross_attention_dim: int = 1280, output_scale_factor: float = 1.0, downsample_padding: int = 1, add_downsample: bool = True, dual_cross_attention: bool = False, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if not dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, out_channels // 
num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, attention_type=attention_type, ) ) else: attentions.append( DualTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, additional_residuals: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") output_states = () blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: hidden_states = hidden_states + additional_residuals output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) return hidden_states, output_states class DownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor: float = 1.0, add_downsample: bool = True, downsample_padding: int = 1, ): 
super().__init__() resnets = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
deprecate("scale", "1.0.0", deprecation_message) output_states = () for resnet in self.resnets: if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, use_reentrant=False ) else: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) else: hidden_states = resnet(hidden_states, temb) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) return hidden_states, output_states class DownEncoderBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor: float = 1.0, add_downsample: bool = True, downsample_padding: int = 1, ): super().__init__() resnets = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels if resnet_time_scale_shift == "spatial": resnets.append( ResnetBlockCondNorm2D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm="spatial", non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, ) ) else: resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.resnets = nn.ModuleList(resnets) if 
add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=None) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) return hidden_states class AttnDownEncoderBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim: int = 1, output_scale_factor: float = 1.0, add_downsample: bool = True, downsample_padding: int = 1, ): super().__init__() resnets = [] attentions = [] if attention_head_dim is None: logger.warning( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." 
) attention_head_dim = out_channels for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels if resnet_time_scale_shift == "spatial": resnets.append( ResnetBlockCondNorm2D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm="spatial", non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, ) ) else: resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) attentions.append( Attention( out_channels, heads=out_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
deprecate("scale", "1.0.0", deprecation_message) for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=None) hidden_states = attn(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) return hidden_states class AttnSkipDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, attention_head_dim: int = 1, output_scale_factor: float = np.sqrt(2.0), add_downsample: bool = True, ): super().__init__() self.attentions = nn.ModuleList([]) self.resnets = nn.ModuleList([]) if attention_head_dim is None: logger.warning( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." ) attention_head_dim = out_channels for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels self.resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min(in_channels // 4, 32), groups_out=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.attentions.append( Attention( out_channels, heads=out_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=32, residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) if add_downsample: self.resnet_down = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min(out_channels // 4, 32), dropout=dropout, 
time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_in_shortcut=True, down=True, kernel="fir", ) self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) else: self.resnet_down = None self.downsamplers = None self.skip_conv = None def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, skip_sample: Optional[torch.Tensor] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) output_states = () for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states) output_states += (hidden_states,) if self.downsamplers is not None: hidden_states = self.resnet_down(hidden_states, temb) for downsampler in self.downsamplers: skip_sample = downsampler(skip_sample) hidden_states = self.skip_conv(skip_sample) + hidden_states output_states += (hidden_states,) return hidden_states, output_states, skip_sample class SkipDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, output_scale_factor: float = np.sqrt(2.0), add_downsample: bool = True, downsample_padding: int = 1, ): super().__init__() self.resnets = nn.ModuleList([]) for i in range(num_layers): in_channels 
= in_channels if i == 0 else out_channels self.resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min(in_channels // 4, 32), groups_out=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if add_downsample: self.resnet_down = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_in_shortcut=True, down=True, kernel="fir", ) self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) else: self.resnet_down = None self.downsamplers = None self.skip_conv = None def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, skip_sample: Optional[torch.Tensor] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
deprecate("scale", "1.0.0", deprecation_message) output_states = () for resnet in self.resnets: hidden_states = resnet(hidden_states, temb) output_states += (hidden_states,) if self.downsamplers is not None: hidden_states = self.resnet_down(hidden_states, temb) for downsampler in self.downsamplers: skip_sample = downsampler(skip_sample) hidden_states = self.skip_conv(skip_sample) + hidden_states output_states += (hidden_states,) return hidden_states, output_states, skip_sample class ResnetDownsampleBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor: float = 1.0, add_downsample: bool = True, skip_time_act: bool = False, ): super().__init__() resnets = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, ) ) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, down=True, ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: 
if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) output_states = () for resnet in self.resnets: if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, use_reentrant=False ) else: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) else: hidden_states = resnet(hidden_states, temb) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states, temb) output_states = output_states + (hidden_states,) return hidden_states, output_states class SimpleCrossAttnDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim: int = 1, cross_attention_dim: int = 1280, output_scale_factor: float = 1.0, add_downsample: bool = True, skip_time_act: bool = False, only_cross_attention: bool = False, cross_attention_norm: Optional[str] = None, ): super().__init__() self.has_cross_attention = True resnets = [] attentions = [] self.attention_head_dim = attention_head_dim self.num_heads = out_channels // self.attention_head_dim for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( 
ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, ) ) processor = ( AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() ) attentions.append( Attention( query_dim=out_channels, cross_attention_dim=out_channels, heads=self.num_heads, dim_head=attention_head_dim, added_kv_proj_dim=cross_attention_dim, norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, processor=processor, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, down=True, ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") output_states = () if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. mask = None if encoder_hidden_states is None else encoder_attention_mask else: # when attention_mask is defined: we don't even check for encoder_attention_mask. # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. # then we can simplify this whole if/else block to: # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=mask, **cross_attention_kwargs, ) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=mask, **cross_attention_kwargs, ) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states, temb) output_states = output_states + (hidden_states,) return hidden_states, output_states class KDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 4, resnet_eps: float = 1e-5, resnet_act_fn: str = "gelu", resnet_group_size: int = 32, add_downsample: bool = False, ): super().__init__() resnets = [] for i in 
range(num_layers): in_channels = in_channels if i == 0 else out_channels groups = in_channels // resnet_group_size groups_out = out_channels // resnet_group_size resnets.append( ResnetBlockCondNorm2D( in_channels=in_channels, out_channels=out_channels, dropout=dropout, temb_channels=temb_channels, groups=groups, groups_out=groups_out, eps=resnet_eps, non_linearity=resnet_act_fn, time_embedding_norm="ada_group", conv_shortcut_bias=False, ) ) self.resnets = nn.ModuleList(resnets) if add_downsample: # YiYi's comments- might be able to use FirDownsample2D, look into details later self.downsamplers = nn.ModuleList([KDownsample2D()]) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
deprecate("scale", "1.0.0", deprecation_message) output_states = () for resnet in self.resnets: if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, use_reentrant=False ) else: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) else: hidden_states = resnet(hidden_states, temb) output_states += (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) return hidden_states, output_states class KCrossAttnDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, cross_attention_dim: int, dropout: float = 0.0, num_layers: int = 4, resnet_group_size: int = 32, add_downsample: bool = True, attention_head_dim: int = 64, add_self_attention: bool = False, resnet_eps: float = 1e-5, resnet_act_fn: str = "gelu", ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = True for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels groups = in_channels // resnet_group_size groups_out = out_channels // resnet_group_size resnets.append( ResnetBlockCondNorm2D( in_channels=in_channels, out_channels=out_channels, dropout=dropout, temb_channels=temb_channels, groups=groups, groups_out=groups_out, eps=resnet_eps, non_linearity=resnet_act_fn, time_embedding_norm="ada_group", conv_shortcut_bias=False, ) ) attentions.append( KAttentionBlock( out_channels, out_channels // attention_head_dim, attention_head_dim, cross_attention_dim=cross_attention_dim, temb_channels=temb_channels, attention_bias=True, add_self_attention=add_self_attention, cross_attention_norm="layer_norm", group_size=resnet_group_size, ) ) self.resnets = 
nn.ModuleList(resnets) self.attentions = nn.ModuleList(attentions) if add_downsample: self.downsamplers = nn.ModuleList([KDownsample2D()]) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") output_states = () for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) if self.downsamplers is None: output_states += (None,) else: output_states += (hidden_states,) if self.downsamplers is not None: for downsampler in 
self.downsamplers: hidden_states = downsampler(hidden_states) return hidden_states, output_states class AttnUpBlock2D(nn.Module): def __init__( self, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim: int = 1, output_scale_factor: float = 1.0, upsample_type: str = "conv", ): super().__init__() resnets = [] attentions = [] self.upsample_type = upsample_type if attention_head_dim is None: logger.warning( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." ) attention_head_dim = out_channels for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) attentions.append( Attention( out_channels, heads=out_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if upsample_type == "conv": self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) elif upsample_type == "resnet": self.upsamplers = nn.ModuleList( [ ResnetBlock2D( in_channels=out_channels, 
class CrossAttnUpBlock2D(nn.Module):
    """Up block pairing each ResNet layer with a cross-attention transformer.

    Every layer pops one skip tensor from the end of `res_hidden_states_tuple`,
    concatenates it on dim 1, runs a `ResnetBlock2D` and then a
    `Transformer2DModel` (or `DualTransformer2DModel` when
    `dual_cross_attention` is set). An `Upsample2D` is appended when
    `add_upsample` is true.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # A single int means "same transformer depth for every layer".
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        resnet_layers = []
        attention_layers = []
        final_idx = num_layers - 1
        for idx in range(num_layers):
            skip_channels = in_channels if idx == final_idx else out_channels
            main_channels = prev_output_channel if idx == 0 else out_channels

            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=main_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

            if dual_cross_attention:
                transformer = DualTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            else:
                transformer = Transformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[idx],
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    attention_type=attention_type,
                )
            attention_layers.append(transformer)

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply all resnet/transformer pairs and the optional upsampler.

        When the attributes `s1`, `s2`, `b1` and `b2` are all present and
        truthy, `apply_freeu` is applied to the hidden/skip pair before
        concatenation. In training mode with `gradient_checkpointing` enabled
        only the resnet call is checkpointed.
        """
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        is_freeu_enabled = all(getattr(self, attr, None) for attr in ("s1", "s2", "b1", "b2"))

        for resnet, attn in zip(self.resnets, self.attentions):
            # Take the newest skip connection and shrink the tuple.
            res_hidden_states, res_hidden_states_tuple = res_hidden_states_tuple[-1], res_hidden_states_tuple[:-1]

            # FreeU: Only operate on the first two stages
            if is_freeu_enabled:
                hidden_states, res_hidden_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_hidden_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(resnet, hidden_states, temb, **ckpt_kwargs)
            else:
                hidden_states = resnet(hidden_states, temb)

            # The transformer returns a tuple when `return_dict=False`.
            attn_output = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )
            hidden_states = attn_output[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
class UpBlock2D(nn.Module):
    """Plain ResNet up block.

    Each layer pops one skip tensor from the end of `res_hidden_states_tuple`,
    concatenates it with the hidden states on dim 1 and applies a
    `ResnetBlock2D`. An `Upsample2D` follows when `add_upsample` is true.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
    ):
        super().__init__()

        final_idx = num_layers - 1
        resnet_layers = [
            ResnetBlock2D(
                # main-path channels + skip-connection channels
                in_channels=(prev_output_channel if idx == 0 else out_channels)
                + (in_channels if idx == final_idx else out_channels),
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
            for idx in range(num_layers)
        ]
        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        upsample_size: Optional[int] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Apply every resnet (optionally checkpointed) and the optional upsampler.

        Extra positional arguments or a non-`None` `scale` keyword only raise a
        deprecation notice. FreeU is applied when `s1`, `s2`, `b1` and `b2`
        are all set to truthy values on the module.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        is_freeu_enabled = all(getattr(self, attr, None) for attr in ("s1", "s2", "b1", "b2"))

        for resnet in self.resnets:
            # Take the newest skip connection and shrink the tuple.
            res_hidden_states, res_hidden_states_tuple = res_hidden_states_tuple[-1], res_hidden_states_tuple[:-1]

            # FreeU: Only operate on the first two stages
            if is_freeu_enabled:
                hidden_states, res_hidden_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_hidden_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:
                # `use_reentrant` is only passed on torch >= 1.11.0.
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(resnet, hidden_states, temb, **ckpt_kwargs)
            else:
                hidden_states = resnet(hidden_states, temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
class UpDecoderBlock2D(nn.Module):
    """Decoder up block: a stack of resnets with no skip connections.

    With `resnet_time_scale_shift == "spatial"` the layers are
    `ResnetBlockCondNorm2D`; otherwise they are `ResnetBlock2D`. An
    `Upsample2D` follows when `add_upsample` is true.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        temb_channels: Optional[int] = None,
    ):
        super().__init__()

        uses_spatial_norm = resnet_time_scale_shift == "spatial"
        resnet_layers = []
        for idx in range(num_layers):
            # Keyword arguments common to both resnet variants.
            shared_kwargs = dict(
                in_channels=in_channels if idx == 0 else out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
            )
            if uses_spatial_norm:
                layer = ResnetBlockCondNorm2D(time_embedding_norm="spatial", **shared_kwargs)
            else:
                layer = ResnetBlock2D(
                    time_embedding_norm=resnet_time_scale_shift,
                    pre_norm=resnet_pre_norm,
                    **shared_kwargs,
                )
            resnet_layers.append(layer)

        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.resolution_idx = resolution_idx

    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the resnets in order (passing `temb` by keyword), then upsample if configured."""
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=temb)

        for upsampler in self.upsamplers if self.upsamplers is not None else ():
            hidden_states = upsampler(hidden_states)

        return hidden_states
class AttnUpDecoderBlock2D(nn.Module):
    """Decoder up block alternating resnets with self-attention layers.

    With `resnet_time_scale_shift == "spatial"` the resnets are
    `ResnetBlockCondNorm2D` and the attention layers get
    `spatial_norm_dim=temb_channels` instead of group norm; otherwise the
    resnets are `ResnetBlock2D` and attention uses `norm_num_groups=resnet_groups`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        temb_channels: Optional[int] = None,
    ):
        super().__init__()

        if attention_head_dim is None:
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
            )
            attention_head_dim = out_channels

        uses_spatial_norm = resnet_time_scale_shift == "spatial"
        resnet_layers = []
        attention_layers = []
        for idx in range(num_layers):
            # Keyword arguments common to both resnet variants.
            shared_kwargs = dict(
                in_channels=in_channels if idx == 0 else out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
            )
            if uses_spatial_norm:
                layer = ResnetBlockCondNorm2D(time_embedding_norm="spatial", **shared_kwargs)
            else:
                layer = ResnetBlock2D(
                    time_embedding_norm=resnet_time_scale_shift,
                    pre_norm=resnet_pre_norm,
                    **shared_kwargs,
                )
            resnet_layers.append(layer)

            attention_layers.append(
                Attention(
                    out_channels,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    # Group norm and spatial norm are mutually exclusive here.
                    norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None,
                    spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.resolution_idx = resolution_idx

    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run each resnet then its attention layer (both receive `temb` by keyword), then upsample if configured."""
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = attn(resnet(hidden_states, temb=temb), temb=temb)

        for upsampler in self.upsamplers if self.upsamplers is not None else ():
            hidden_states = upsampler(hidden_states)

        return hidden_states
class AttnSkipUpBlock2D(nn.Module):
    """Up block with resnets, one attention layer and a separate skip-sample branch.

    `forward` returns both the hidden states and a `skip_sample` accumulated
    through `FirUpsample2D`. When `add_upsample` is true the block also owns a
    FIR-upsampling `ResnetBlock2D` plus a GroupNorm → SiLU → 3-channel conv head
    whose output is added to `skip_sample`.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_pre_norm: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = np.sqrt(2.0),
        add_upsample: bool = True,
    ):
        super().__init__()
        self.attentions = nn.ModuleList([])
        self.resnets = nn.ModuleList([])

        final_idx = num_layers - 1
        for idx in range(num_layers):
            skip_channels = in_channels if idx == final_idx else out_channels
            main_channels = prev_output_channel if idx == 0 else out_channels

            self.resnets.append(
                ResnetBlock2D(
                    in_channels=main_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    # NOTE(review): `//` binds tighter than `+`, so this is
                    # main + (skip // 4), unlike SkipUpBlock2D's (main + skip) // 4.
                    # Left as-is because it changes GroupNorm grouping — confirm
                    # against trained checkpoints before altering.
                    groups=min(main_channels + skip_channels // 4, 32),
                    groups_out=min(out_channels // 4, 32),
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        if attention_head_dim is None:
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
            )
            attention_head_dim = out_channels

        # A single attention layer, applied once after all resnets in `forward`.
        self.attentions.append(
            Attention(
                out_channels,
                heads=out_channels // attention_head_dim,
                dim_head=attention_head_dim,
                rescale_output_factor=output_scale_factor,
                eps=resnet_eps,
                norm_num_groups=32,
                residual_connection=True,
                bias=True,
                upcast_softmax=True,
                _from_deprecated_attn_block=True,
            )
        )

        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)

        if not add_upsample:
            self.resnet_up = None
            self.skip_conv = None
            self.skip_norm = None
            self.act = None
        else:
            out_groups = min(out_channels // 4, 32)
            self.resnet_up = ResnetBlock2D(
                in_channels=out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=out_groups,
                groups_out=out_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_in_shortcut=True,
                up=True,
                kernel="fir",
            )
            self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            self.skip_norm = torch.nn.GroupNorm(
                num_groups=out_groups, num_channels=out_channels, eps=resnet_eps, affine=True
            )
            self.act = nn.SiLU()

        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        skip_sample=None,
        *args,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return `(hidden_states, skip_sample)`.

        `skip_sample` is upsampled when given, otherwise it starts at `0`.
        When `resnet_up` exists, the normalised/activated/convolved hidden
        states are added to it and the hidden states are upsampled.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        for resnet in self.resnets:
            # Take the newest skip connection and shrink the tuple.
            res_hidden_states, res_hidden_states_tuple = res_hidden_states_tuple[-1], res_hidden_states_tuple[:-1]
            hidden_states = resnet(torch.cat([hidden_states, res_hidden_states], dim=1), temb)

        hidden_states = self.attentions[0](hidden_states)

        skip_sample = 0 if skip_sample is None else self.upsampler(skip_sample)

        if self.resnet_up is not None:
            skip_contribution = self.skip_conv(self.act(self.skip_norm(hidden_states)))
            skip_sample = skip_sample + skip_contribution
            hidden_states = self.resnet_up(hidden_states, temb)

        return hidden_states, skip_sample
class SkipUpBlock2D(nn.Module):
    """Up block with resnets and a separate skip-sample branch (no attention).

    `forward` returns both the hidden states and a `skip_sample` accumulated
    through `FirUpsample2D`. When `add_upsample` is true the block also owns a
    FIR-upsampling `ResnetBlock2D` plus a GroupNorm → SiLU → 3-channel conv head
    whose output is added to `skip_sample`. `upsample_padding` is accepted but
    not used by this block.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_pre_norm: bool = True,
        output_scale_factor: float = np.sqrt(2.0),
        add_upsample: bool = True,
        upsample_padding: int = 1,
    ):
        super().__init__()
        self.resnets = nn.ModuleList([])

        final_idx = num_layers - 1
        for idx in range(num_layers):
            skip_channels = in_channels if idx == final_idx else out_channels
            main_channels = prev_output_channel if idx == 0 else out_channels
            combined_channels = main_channels + skip_channels

            self.resnets.append(
                ResnetBlock2D(
                    in_channels=combined_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=min(combined_channels // 4, 32),
                    groups_out=min(out_channels // 4, 32),
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)

        if not add_upsample:
            self.resnet_up = None
            self.skip_conv = None
            self.skip_norm = None
            self.act = None
        else:
            out_groups = min(out_channels // 4, 32)
            self.resnet_up = ResnetBlock2D(
                in_channels=out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=out_groups,
                groups_out=out_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_in_shortcut=True,
                up=True,
                kernel="fir",
            )
            self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            self.skip_norm = torch.nn.GroupNorm(
                num_groups=out_groups, num_channels=out_channels, eps=resnet_eps, affine=True
            )
            self.act = nn.SiLU()

        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        skip_sample=None,
        *args,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return `(hidden_states, skip_sample)`.

        `skip_sample` is upsampled when given, otherwise it starts at `0`.
        When `resnet_up` exists, the normalised/activated/convolved hidden
        states are added to it and the hidden states are upsampled.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        for resnet in self.resnets:
            # Take the newest skip connection and shrink the tuple.
            res_hidden_states, res_hidden_states_tuple = res_hidden_states_tuple[-1], res_hidden_states_tuple[:-1]
            hidden_states = resnet(torch.cat([hidden_states, res_hidden_states], dim=1), temb)

        skip_sample = 0 if skip_sample is None else self.upsampler(skip_sample)

        if self.resnet_up is not None:
            skip_contribution = self.skip_conv(self.act(self.skip_norm(hidden_states)))
            skip_sample = skip_sample + skip_contribution
            hidden_states = self.resnet_up(hidden_states, temb)

        return hidden_states, skip_sample
class ResnetUpsampleBlock2D(nn.Module):
    """ResNet up block whose upsampler is itself a `ResnetBlock2D(up=True)`.

    Each layer pops one skip tensor from the end of `res_hidden_states_tuple`,
    concatenates it on dim 1 and applies a `ResnetBlock2D`. The upsampling
    resnet (present when `add_upsample` is true) also receives `temb`.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        skip_time_act: bool = False,
    ):
        super().__init__()

        final_idx = num_layers - 1
        resnet_layers = [
            ResnetBlock2D(
                # main-path channels + skip-connection channels
                in_channels=(prev_output_channel if idx == 0 else out_channels)
                + (in_channels if idx == final_idx else out_channels),
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                skip_time_act=skip_time_act,
            )
            for idx in range(num_layers)
        ]
        self.resnets = nn.ModuleList(resnet_layers)

        if not add_upsample:
            self.upsamplers = None
        else:
            upsample_resnet = ResnetBlock2D(
                in_channels=out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                skip_time_act=skip_time_act,
                up=True,
            )
            self.upsamplers = nn.ModuleList([upsample_resnet])

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        upsample_size: Optional[int] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Apply every resnet (optionally checkpointed) and the optional upsampling resnet.

        Extra positional arguments or a non-`None` `scale` keyword only raise a
        deprecation notice. `upsample_size` is accepted but not used here.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        for resnet in self.resnets:
            # Take the newest skip connection and shrink the tuple.
            res_hidden_states, res_hidden_states_tuple = res_hidden_states_tuple[-1], res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:
                # `use_reentrant` is only passed on torch >= 1.11.0.
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(resnet, hidden_states, temb, **ckpt_kwargs)
            else:
                hidden_states = resnet(hidden_states, temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, temb)

        return hidden_states
class SimpleCrossAttnUpBlock2D(nn.Module):
    """Up block pairing each resnet with an added-KV cross-attention layer.

    Every layer pops one skip tensor from the end of `res_hidden_states_tuple`,
    concatenates it on dim 1, runs a `ResnetBlock2D` and then an `Attention`
    layer configured with `added_kv_proj_dim=cross_attention_dim`. The optional
    upsampler is a `ResnetBlock2D(up=True)` that also receives `temb`.

    Parameters:
        in_channels: Channel count of the skip tensor consumed by the last layer.
        out_channels: Channel count produced by every layer.
        prev_output_channel: Channel count of the hidden states entering the first layer.
        temb_channels: Channel count forwarded to the resnets as `temb_channels`.
        attention_head_dim: Channels per head; the head count is
            `out_channels // attention_head_dim`.
        cross_attention_dim: Passed to `Attention` as `added_kv_proj_dim`.
        add_upsample: Whether to append the upsampling resnet.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        skip_time_act: bool = False,
        only_cross_attention: bool = False,
        cross_attention_norm: Optional[str] = None,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.attention_head_dim = attention_head_dim

        self.num_heads = out_channels // self.attention_head_dim

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    skip_time_act=skip_time_act,
                )
            )

            # Use the SDPA-based processor when this torch build provides it.
            processor = (
                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
            )

            attentions.append(
                Attention(
                    query_dim=out_channels,
                    cross_attention_dim=out_channels,
                    heads=self.num_heads,
                    dim_head=self.attention_head_dim,
                    added_kv_proj_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    bias=True,
                    upcast_softmax=True,
                    only_cross_attention=only_cross_attention,
                    cross_attention_norm=cross_attention_norm,
                    processor=processor,
                )
            )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    ResnetBlock2D(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                        skip_time_act=skip_time_act,
                        up=True,
                    )
                ]
            )
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply all resnet/attention pairs and the optional upsampling resnet.

        `attention_mask`, when given, is always used as the attention mask;
        otherwise `encoder_attention_mask` is used if `encoder_hidden_states`
        is provided. `cross_attention_kwargs` are splatted into each attention
        call. In training mode with `gradient_checkpointing` enabled only the
        resnet call is checkpointed. `upsample_size` is accepted but not used.
        """
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        if cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        if attention_mask is None:
            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
            mask = None if encoder_hidden_states is None else encoder_attention_mask
        else:
            # when attention_mask is defined: we don't even check for encoder_attention_mask.
            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
            #       then we can simplify this whole if/else block to:
            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
            mask = attention_mask

        for resnet, attn in zip(self.resnets, self.attentions):
            # resnet
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                # Pass `use_reentrant` explicitly (as the sibling blocks do) so
                # recent PyTorch versions do not warn about the missing argument.
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=mask,
                **cross_attention_kwargs,
            )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, temb)

        return hidden_states
class KUpBlock2D(nn.Module):
    """K-style up block: one skip concatenation followed by `num_layers - 1` resnets.

    The resnets are `ResnetBlockCondNorm2D` with `time_embedding_norm="ada_group"`.
    The first one takes `2 * out_channels` input channels and the last one
    emits `in_channels` channels; all others map `out_channels -> out_channels`.
    A `KUpsample2D` follows when `add_upsample` is true.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        resolution_idx: int,
        dropout: float = 0.0,
        num_layers: int = 5,
        resnet_eps: float = 1e-5,
        resnet_act_fn: str = "gelu",
        resnet_group_size: Optional[int] = 32,
        add_upsample: bool = True,
    ):
        super().__init__()

        first_in_channels = 2 * out_channels
        last_out_channels = in_channels
        # One fewer resnet than `num_layers`.
        resnet_count = num_layers - 1

        resnet_layers = []
        for idx in range(resnet_count):
            layer_in = first_in_channels if idx == 0 else out_channels
            layer_out = last_out_channels if idx == resnet_count - 1 else out_channels

            resnet_layers.append(
                ResnetBlockCondNorm2D(
                    in_channels=layer_in,
                    out_channels=layer_out,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=layer_in // resnet_group_size,
                    groups_out=out_channels // resnet_group_size,
                    dropout=dropout,
                    non_linearity=resnet_act_fn,
                    time_embedding_norm="ada_group",
                    conv_shortcut_bias=False,
                )
            )

        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = nn.ModuleList([KUpsample2D()]) if add_upsample else None

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        upsample_size: Optional[int] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Concatenate the last skip tensor (if not `None`), run the resnets, then upsample.

        Only the final element of `res_hidden_states_tuple` is used. Extra
        positional arguments or a non-`None` `scale` keyword only raise a
        deprecation notice. `upsample_size` is accepted but not used here.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        skip_states = res_hidden_states_tuple[-1]
        if skip_states is not None:
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

        for resnet in self.resnets:
            if self.training and self.gradient_checkpointing:
                # `use_reentrant` is only passed on torch >= 1.11.0.
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(resnet, hidden_states, temb, **ckpt_kwargs)
            else:
                hidden_states = resnet(hidden_states, temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
class KCrossAttnUpBlock2D(nn.Module):
    """K-style up block pairing `ResnetBlockCondNorm2D` layers with `KAttentionBlock`s.

    The block builds `num_layers - 1` resnet/attention pairs. It counts as the
    "first block" when `in_channels == out_channels == temb_channels` (which
    enables self-attention and skips the input doubling) and as a "middle
    block" when `in_channels != out_channels` (which makes the last resnet use
    `conv_2d_out_channels=in_channels`). A `KUpsample2D` follows when
    `add_upsample` is true.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        resolution_idx: int,
        dropout: float = 0.0,
        num_layers: int = 4,
        resnet_eps: float = 1e-5,
        resnet_act_fn: str = "gelu",
        resnet_group_size: int = 32,
        attention_head_dim: int = 1,  # attention dim_head
        cross_attention_dim: int = 768,
        add_upsample: bool = True,
        upcast_attention: bool = False,
    ):
        super().__init__()

        is_first_block = in_channels == out_channels == temb_channels
        is_middle_block = in_channels != out_channels
        add_self_attention = is_first_block

        self.has_cross_attention = True
        self.attention_head_dim = attention_head_dim

        # in_channels, and out_channels for the block (k-unet)
        first_in_channels = out_channels if is_first_block else 2 * out_channels
        last_out_channels = in_channels
        # One fewer pair than `num_layers`.
        pair_count = num_layers - 1

        resnet_layers = []
        attention_layers = []
        for idx in range(pair_count):
            is_last = idx == pair_count - 1
            layer_in = first_in_channels if idx == 0 else out_channels

            resnet_layers.append(
                ResnetBlockCondNorm2D(
                    in_channels=layer_in,
                    out_channels=out_channels,
                    # Only the last resnet of a middle block changes its conv output width.
                    conv_2d_out_channels=last_out_channels if (is_middle_block and is_last) else None,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=layer_in // resnet_group_size,
                    groups_out=out_channels // resnet_group_size,
                    dropout=dropout,
                    non_linearity=resnet_act_fn,
                    time_embedding_norm="ada_group",
                    conv_shortcut_bias=False,
                )
            )

            attn_channels = last_out_channels if is_last else out_channels
            attention_layers.append(
                KAttentionBlock(
                    attn_channels,
                    attn_channels // attention_head_dim,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    temb_channels=temb_channels,
                    attention_bias=True,
                    add_self_attention=add_self_attention,
                    cross_attention_norm="layer_norm",
                    upcast_attention=upcast_attention,
                )
            )

        self.resnets = nn.ModuleList(resnet_layers)
        self.attentions = nn.ModuleList(attention_layers)

        self.upsamplers = nn.ModuleList([KUpsample2D()]) if add_upsample else None

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Concatenate the last skip tensor (if not `None`), run all pairs, then upsample.

        Only the final element of `res_hidden_states_tuple` is used. In
        training mode with `gradient_checkpointing` enabled only the resnet
        call is checkpointed. `upsample_size` is accepted but not used here.
        """
        skip_states = res_hidden_states_tuple[-1]
        if skip_states is not None:
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

        for resnet, attn in zip(self.resnets, self.attentions):
            if self.training and self.gradient_checkpointing:
                # `use_reentrant` is only passed on torch >= 1.11.0.
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(resnet, hidden_states, temb, **ckpt_kwargs)
            else:
                hidden_states = resnet(hidden_states, temb)

            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                emb=temb,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
# can potentially later be renamed to `No-feed-forward` attention
class KAttentionBlock(nn.Module):
    r"""
    Attention-only block operating on 4D feature maps.

    The block applies an optional self-attention step followed by a cross-attention step. Each step normalizes its
    input with an `AdaGroupNorm` conditioned on `emb`, runs attention over the flattened spatial positions, and adds
    the result back to the input as a residual. There is no feed-forward sub-layer.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability passed to the attention layers.
        cross_attention_dim (`int`, *optional*): The size of the `encoder_hidden_states` vectors for cross attention.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Passed as `bias` to both attention layers.
        upcast_attention (`bool`, *optional*, defaults to `False`):
            Passed to the cross-attention layer only; the self-attention layer does not receive it.
        temb_channels (`int`, *optional*, defaults to 768):
            Forwarded as the first argument of both `AdaGroupNorm` layers (the size of `emb`).
        add_self_attention (`bool`, *optional*, defaults to `False`):
            Set to `True` to create and run the self-attention step (`norm1` / `attn1`).
        cross_attention_norm (`str`, *optional*, defaults to `None`):
            Passed as `cross_attention_norm` to the cross-attention layer.
        group_size (`int`, *optional*, defaults to 32):
            Channels per normalization group; the group count given to `AdaGroupNorm` is `max(1, dim // group_size)`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        upcast_attention: bool = False,
        temb_channels: int = 768,  # for ada_group_norm
        add_self_attention: bool = False,
        cross_attention_norm: Optional[str] = None,
        group_size: int = 32,
    ):
        super().__init__()
        self.add_self_attention = add_self_attention

        # Both AdaGroupNorm layers are built with the same group count.
        num_groups = max(1, dim // group_size)

        # Keyword arguments that the self- and cross-attention layers have in common.
        shared_attn_kwargs = {
            "query_dim": dim,
            "heads": num_attention_heads,
            "dim_head": attention_head_dim,
            "dropout": dropout,
            "bias": attention_bias,
        }

        # 1. Self-Attn (submodules exist only when requested)
        if add_self_attention:
            self.norm1 = AdaGroupNorm(temb_channels, dim, num_groups)
            self.attn1 = Attention(
                cross_attention_dim=None,
                cross_attention_norm=None,
                **shared_attn_kwargs,
            )

        # 2. Cross-Attn
        self.norm2 = AdaGroupNorm(temb_channels, dim, num_groups)
        self.attn2 = Attention(
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            cross_attention_norm=cross_attention_norm,
            **shared_attn_kwargs,
        )

    def _to_3d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
        """Flatten a `(batch, channels, height, width)` map into `(batch, height * width, channels)` tokens."""
        batch_size = hidden_states.shape[0]
        channels_last = hidden_states.permute(0, 2, 3, 1)
        return channels_last.reshape(batch_size, height * weight, -1)

    def _to_4d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
        """Inverse of `_to_3d`: fold `(batch, height * width, channels)` tokens back into a 4D map."""
        batch_size = hidden_states.shape[0]
        channels_first = hidden_states.permute(0, 2, 1)
        return channels_first.reshape(batch_size, -1, height, weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        # TODO: mark emb as non-optional (self.norm2 requires it).
        #       requires assessing impact of change to positional param interface.
        emb: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the optional self-attention step and then the cross-attention step, each as a residual update.

        Args:
            hidden_states: 4D input feature map; its last two dimensions are treated as the spatial size.
            encoder_hidden_states: Conditioning passed to the cross-attention layer. When `None`, the second
                attention layer receives `attention_mask` instead of `encoder_attention_mask`.
            emb: Conditioning embedding handed to the `AdaGroupNorm` layers.
            attention_mask: Mask for the self-attention step (and for the second step when no encoder states
                are given).
            cross_attention_kwargs: Extra keyword arguments forwarded to both attention layers. A non-`None`
                `"scale"` entry only triggers a deprecation warning here.
            encoder_attention_mask: Mask used by the second step when `encoder_hidden_states` is provided.

        Returns:
            The updated feature map.
        """
        if cross_attention_kwargs is None:
            cross_attention_kwargs = {}
        if cross_attention_kwargs.get("scale") is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # 1. Self-Attention
        if self.add_self_attention:
            normed = self.norm1(hidden_states, emb)
            height, weight = normed.shape[2:]
            tokens = self._to_3d(normed, height, weight)
            tokens = self.attn1(
                tokens,
                encoder_hidden_states=None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = self._to_4d(tokens, height, weight) + hidden_states

        # 2. Cross-Attention/None
        normed = self.norm2(hidden_states, emb)
        height, weight = normed.shape[2:]
        tokens = self._to_3d(normed, height, weight)

        # Without encoder states the second layer attends over the input itself, so it takes the self mask.
        if encoder_hidden_states is None:
            second_mask = attention_mask
        else:
            second_mask = encoder_attention_mask

        tokens = self.attn2(
            tokens,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=second_mask,
            **cross_attention_kwargs,
        )
        hidden_states = self._to_4d(tokens, height, weight) + hidden_states

        return hidden_states
class FlaxCrossAttnDownBlock2D(nn.Module):
    r"""
    Cross Attention 2D Downsizing block - original architecture from Unet transformers:
    https://arxiv.org/abs/2103.06104

    Each layer is a `FlaxResnetBlock2D` followed by a `FlaxTransformer2DModel`; an optional `FlaxDownsample2D` is
    applied after the last layer.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of resnet + attention layer pairs
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add a downsampling layer after the last resnet + attention pair
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to `FlaxTransformer2DModel`
        only_cross_attention (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to `FlaxTransformer2DModel`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        transformer_layers_per_block (:obj:`int`, *optional*, defaults to 1):
            Passed as `depth` to each `FlaxTransformer2DModel`
    """

    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    add_downsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32
    transformer_layers_per_block: int = 1

    def setup(self):
        """Build `num_layers` resnet/transformer pairs and the optional downsampler."""
        resnets = []
        attentions = []

        for i in range(self.num_layers):
            # Only the first resnet sees the block's input width; later ones chain on `out_channels`.
            in_channels = self.in_channels if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

            attn_block = FlaxTransformer2DModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

        self.resnets = resnets
        self.attentions = attentions

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        """
        Apply every resnet/attention pair, then the optional downsampler.

        Returns:
            A tuple `(hidden_states, output_states)` where `output_states` collects the output of every layer pair
            and, when present, of the downsampler.
        """
        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            output_states += (hidden_states,)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states


class FlaxDownBlock2D(nn.Module):
    r"""
    Flax 2D downsizing block

    A stack of `FlaxResnetBlock2D` layers followed by an optional `FlaxDownsample2D`.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of resnet layers
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add a downsampling layer after the last resnet
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Build `num_layers` resnets and the optional downsampler."""
        resnets = []

        for i in range(self.num_layers):
            # Only the first resnet sees the block's input width; later ones chain on `out_channels`.
            in_channels = self.in_channels if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)
        self.resnets = resnets

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, deterministic=True):
        """
        Apply every resnet, then the optional downsampler.

        Returns:
            A tuple `(hidden_states, output_states)` where `output_states` collects the output of every resnet and,
            when present, of the downsampler.
        """
        output_states = ()

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            output_states += (hidden_states,)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states


class FlaxCrossAttnUpBlock2D(nn.Module):
    r"""
    Cross Attention 2D Upsampling block - original architecture from Unet transformers:
    https://arxiv.org/abs/2103.06104

    Each layer concatenates one skip tensor onto the running hidden states, then applies a `FlaxResnetBlock2D` and a
    `FlaxTransformer2DModel`; an optional `FlaxUpsample2D` is applied after the last layer.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        prev_output_channel (:obj:`int`):
            Output channels from the previous block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of resnet + attention layer pairs
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add an upsampling layer after the last resnet + attention pair
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to `FlaxTransformer2DModel`
        only_cross_attention (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to `FlaxTransformer2DModel`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        transformer_layers_per_block (:obj:`int`, *optional*, defaults to 1):
            Passed as `depth` to each `FlaxTransformer2DModel`
    """

    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    add_upsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32
    transformer_layers_per_block: int = 1

    def setup(self):
        """Build `num_layers` resnet/transformer pairs and the optional upsampler."""
        resnets = []
        attentions = []

        for i in range(self.num_layers):
            # The last layer is sized for a skip tensor of width `in_channels`, the others for `out_channels`.
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            # The first layer continues from the previous block, the others from this block's own output.
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=resnet_in_channels + res_skip_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

            attn_block = FlaxTransformer2DModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

        self.resnets = resnets
        self.attentions = attentions

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
        """
        Consume one skip tensor per layer (taken from the end of `res_hidden_states_tuple`), apply the
        resnet/attention pair, and finally the optional upsampler.

        Returns:
            The resulting hidden states.
        """
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            # Skip tensors are joined on the last axis (presumably channels-last layout - confirm with caller).
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)

            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states


class FlaxUpBlock2D(nn.Module):
    r"""
    Flax 2D upsampling block

    Each layer concatenates one skip tensor onto the running hidden states and applies a `FlaxResnetBlock2D`; an
    optional `FlaxUpsample2D` is applied after the last layer.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        prev_output_channel (:obj:`int`):
            Output channels from the previous block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of resnet layers
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add an upsampling layer after the last resnet
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Build `num_layers` resnets and the optional upsampler."""
        resnets = []

        for i in range(self.num_layers):
            # The last layer is sized for a skip tensor of width `in_channels`, the others for `out_channels`.
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            # The first layer continues from the previous block, the others from this block's own output.
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=resnet_in_channels + res_skip_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
        """
        Consume one skip tensor per resnet (taken from the end of `res_hidden_states_tuple`), then apply the
        optional upsampler.

        Returns:
            The resulting hidden states.
        """
        for resnet in self.resnets:
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            # Skip tensors are joined on the last axis (presumably channels-last layout - confirm with caller).
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)

            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states


class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    r"""
    Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104

    One leading `FlaxResnetBlock2D` followed by `num_layers` pairs of `FlaxTransformer2DModel` and
    `FlaxResnetBlock2D`. The channel count stays at `in_channels` throughout.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention + resnet layer pairs that follow the first resnet
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to `FlaxTransformer2DModel`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        transformer_layers_per_block (:obj:`int`, *optional*, defaults to 1):
            Passed as `depth` to each `FlaxTransformer2DModel`
    """

    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    use_linear_projection: bool = False
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32
    transformer_layers_per_block: int = 1

    def setup(self):
        """Build the leading resnet plus `num_layers` attention/resnet pairs."""
        # there is always at least one resnet
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
        ]

        attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxTransformer2DModel(
                in_channels=self.in_channels,
                n_heads=self.num_attention_heads,
                d_head=self.in_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        """
        Apply the leading resnet, then each attention/resnet pair in order.

        `deterministic` is forwarded to every sub-layer, including the leading resnet, so that all resnets in the
        block honor the caller's setting.

        Returns:
            The resulting hidden states.
        """
        # Forward `deterministic` here as well; previously the first resnet always ran with its default value.
        hidden_states = self.resnets[0](hidden_states, temb, deterministic=deterministic)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        return hidden_states
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    Result container returned by [`UNet2DConditionModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Output of the model's last layer: the hidden states conditioned on the `encoder_hidden_states` input.
    """

    # Defaults to `None` so the container can be created before a sample is attached.
    sample: torch.Tensor = None
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. 
encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim` resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. addition_time_embed_dim: (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings. num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. 
timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` otherwise. 
""" _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, dropout: float = 0.0, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, time_embedding_type: str = "positional", time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, 
projection_class_embeddings_input_dim: Optional[int] = None, attention_type: str = "default", class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads: int = 64, ): super().__init__() self.sample_size = sample_size if num_attention_heads is not None: raise ValueError( "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." ) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. 
num_attention_heads = num_attention_heads or attention_head_dim # Check inputs self._check_config( down_block_types=down_block_types, up_block_types=up_block_types, only_cross_attention=only_cross_attention, block_out_channels=block_out_channels, layers_per_block=layers_per_block, cross_attention_dim=cross_attention_dim, transformer_layers_per_block=transformer_layers_per_block, reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, attention_head_dim=attention_head_dim, num_attention_heads=num_attention_heads, ) # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time time_embed_dim, timestep_input_dim = self._set_time_proj( time_embedding_type, block_out_channels=block_out_channels, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift, time_embedding_dim=time_embedding_dim, ) self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, ) self._set_encoder_hid_proj( encoder_hid_dim_type, cross_attention_dim=cross_attention_dim, encoder_hid_dim=encoder_hid_dim, ) # class embedding self._set_class_embedding( class_embed_type, act_fn=act_fn, num_class_embeds=num_class_embeds, projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, time_embed_dim=time_embed_dim, timestep_input_dim=timestep_input_dim, ) self._set_add_embedding( addition_embed_type, addition_embed_type_num_heads=addition_embed_type_num_heads, addition_time_embed_dim=addition_time_embed_dim, cross_attention_dim=cross_attention_dim, encoder_hid_dim=encoder_hid_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift, projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, time_embed_dim=time_embed_dim, ) if time_embedding_act_fn is None: self.time_embed_act = None else: self.time_embed_act = get_activation(time_embedding_act_fn) 
self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): if mid_block_only_cross_attention is None: mid_block_only_cross_attention = only_cross_attention only_cross_attention = [only_cross_attention] * len(down_block_types) if mid_block_only_cross_attention is None: mid_block_only_cross_attention = False if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. 
The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the # regular time embeddings blocks_time_embed_dim = time_embed_dim * 2 else: blocks_time_embed_dim = time_embed_dim # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, dropout=dropout, ) self.down_blocks.append(down_block) # mid self.mid_block = get_mid_block( mid_block_type, temb_channels=blocks_time_embed_dim, in_channels=block_out_channels[-1], resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, output_scale_factor=mid_block_scale_factor, transformer_layers_per_block=transformer_layers_per_block[-1], num_attention_heads=num_attention_heads[-1], cross_attention_dim=cross_attention_dim[-1], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, 
mid_block_only_cross_attention=mid_block_only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, resnet_skip_time_act=resnet_skip_time_act, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[-1], dropout=dropout, ) # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) reversed_transformer_layers_per_block = ( list(reversed(transformer_layers_per_block)) if reverse_transformer_layers_per_block is None else reverse_transformer_layers_per_block ) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resolution_idx=i, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, 
only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, dropout=dropout, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out if norm_num_groups is not None: self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) self.conv_act = get_activation(act_fn) else: self.conv_norm_out = None self.conv_act = None conv_out_padding = (conv_out_kernel - 1) // 2 self.conv_out = nn.Conv2d( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) def _check_config( self, down_block_types: Tuple[str], up_block_types: Tuple[str], only_cross_attention: Union[bool, Tuple[bool]], block_out_channels: Tuple[int], layers_per_block: Union[int, Tuple[int]], cross_attention_dim: Union[int, Tuple[int]], transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], reverse_transformer_layers_per_block: bool, attention_head_dim: int, num_attention_heads: Optional[Union[int, Tuple[int]]], ): if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 
) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
) if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: for layer_number_per_block in transformer_layers_per_block: if isinstance(layer_number_per_block, list): raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") def _set_time_proj( self, time_embedding_type: str, block_out_channels: int, flip_sin_to_cos: bool, freq_shift: float, time_embedding_dim: int, ) -> Tuple[int, int]: if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." ) return time_embed_dim, timestep_input_dim def _set_encoder_hid_proj( self, encoder_hid_dim_type: Optional[str], cross_attention_dim: Union[int, Tuple[int]], encoder_hid_dim: Optional[int], ): if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." 
) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 self.encoder_hid_proj = ImageProjection( image_embed_dim=encoder_hid_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type is not None: raise ValueError( f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." ) else: self.encoder_hid_proj = None def _set_class_embedding( self, class_embed_type: Optional[str], act_fn: str, num_class_embeds: Optional[int], projection_class_embeddings_input_dim: Optional[int], time_embed_dim: int, timestep_input_dim: int, ): if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. 
# # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None def _set_add_embedding( self, addition_embed_type: str, addition_embed_type_num_heads: int, addition_time_embed_dim: Optional[int], flip_sin_to_cos: bool, freq_shift: float, cross_attention_dim: Optional[int], encoder_hid_dim: Optional[int], projection_class_embeddings_input_dim: Optional[int], time_embed_dim: int, ): if addition_embed_type == "text": if encoder_hid_dim is not None: text_time_embedding_from_dim = encoder_hid_dim else: text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type == "image": # Kandinsky 2.2 self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): if attention_type in ["gated", "gated-text-image"]: positive_len = 768 if isinstance(cross_attention_dim, int): positive_len = cross_attention_dim elif isinstance(cross_attention_dim, (list, tuple)): positive_len = cross_attention_dim[0] feature_type = "text-only" if attention_type == "gated" else "text-image" self.position_net = GLIGENTextBoundingboxProjection( positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type ) @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. 
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            Either a single processor instance that is set on **all** modules exposing `set_processor`, or
            a dictionary mapping `"<module path>.processor"` to the processor for that module. A dict is
            strongly recommended when setting trainable attention processors.

            The dictionary is only read, never modified, so the same dict can be passed again later.

    Raises:
        ValueError: If a dict is passed whose size differs from the number of attention processors.
        KeyError: If a dict is passed that lacks the key for one of the visited modules.
    """
    count = len(self.attn_processors)

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                # Look the entry up instead of popping it, so the caller's dict is left intact.
                module.set_processor(processor[f"{name}.processor"])

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
    r"""
    Enable sliced attention computation.

    Every descendant module that exposes `set_attention_slice` is visited in pre-order and handed one slice
    size. Slicing computes attention in several steps, trading a little speed for lower memory use.

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            `"auto"` gives each layer half of its `sliceable_head_dim`; `"max"` gives each layer a slice
            size of 1; a single number is applied to every layer; a list supplies one entry per layer, in
            visiting order.

    Raises:
        ValueError: If a list of the wrong length is given, or if any size exceeds the corresponding
            layer's `sliceable_head_dim`.
    """

    def iter_sliceable(module: torch.nn.Module):
        # Pre-order walk: a module is yielded before any of its children are inspected.
        if hasattr(module, "set_attention_slice"):
            yield module
        for child in module.children():
            yield from iter_sliceable(child)

    sliceable_head_dims = [
        layer.sliceable_head_dim for top in self.children() for layer in iter_sliceable(top)
    ]
    num_sliceable_layers = len(sliceable_head_dims)

    if slice_size == "auto":
        # Half the head size is usually a good trade-off between speed and memory.
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # Smallest possible slices.
        slice_size = [1] * num_sliceable_layers

    if not isinstance(slice_size, list):
        slice_size = [slice_size] * num_sliceable_layers

    if len(slice_size) != len(sliceable_head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
        )

    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Second walk in the same order as the first, handing out the sizes one by one.
    remaining_sizes = iter(slice_size)
    for top in self.children():
        for layer in iter_sliceable(top):
            layer.set_attention_slice(next(remaining_sizes))
def disable_freeu(self):
    """Disables the FreeU mechanism.

    Resets the attributes `s1`, `s2`, `b1` and `b2` to `None` on every block in `self.up_blocks` that
    currently has them. Blocks that never had an attribute set are left untouched.
    """
    freeu_keys = ("s1", "s2", "b1", "b2")

    for upsample_block in self.up_blocks:
        for key in freeu_keys:
            # `hasattr` alone is sufficient: an attribute whose value is not None necessarily exists.
            if hasattr(upsample_block, key):
                setattr(upsample_block, key, None)
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    r"""
    Compute the class embedding for `class_labels`, cast to the dtype of `sample`.

    Args:
        sample: Only its `dtype` is used, as the target dtype of the result.
        class_labels: Labels to embed. Required whenever the model has a class embedding.

    Returns:
        The embedded labels, or `None` when `self.class_embedding` is `None`.

    Raises:
        ValueError: If the model has a class embedding but `class_labels` is `None`.
    """
    embedder = self.class_embedding
    if embedder is None:
        return None

    if class_labels is None:
        raise ValueError("class_labels should be provided when num_class_embeds > 0")

    target_dtype = sample.dtype

    if self.config.class_embed_type == "timestep":
        # The time projection output is cast to the sample dtype before it is embedded.
        class_labels = self.time_proj(class_labels).to(dtype=target_dtype)

    return embedder(class_labels).to(dtype=target_dtype)
def process_encoder_hidden_states(
    self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
) -> torch.Tensor:
    r"""
    Apply the configured encoder-hidden-state projection, if any.

    Args:
        encoder_hidden_states: The hidden states to project.
        added_cond_kwargs: Must contain `image_embeds` for the `"text_image_proj"`, `"image_proj"` and
            `"ip_image_proj"` projection types.

    Returns:
        The projected hidden states. For `"ip_image_proj"` the result is a tuple
        `(encoder_hidden_states, image_embeds)`. The input is returned unchanged when
        `self.encoder_hid_proj` is `None` or the configured type is none of the handled ones.

    Raises:
        ValueError: If `image_embeds` is required but missing from `added_cond_kwargs`.
    """
    projection = self.encoder_hid_proj
    if projection is None:
        return encoder_hidden_states

    proj_type = self.config.encoder_hid_dim_type

    if proj_type == "text_proj":
        return projection(encoder_hidden_states)

    if proj_type == "text_image_proj":
        # Kandinsky 2.1 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
            )
        return projection(encoder_hidden_states, added_cond_kwargs.get("image_embeds"))

    if proj_type == "image_proj":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
            )
        return projection(added_cond_kwargs.get("image_embeds"))

    if proj_type == "ip_image_proj":
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
            )
        # An optional, separately attached text projection runs before the image projection.
        text_projection = getattr(self, "text_encoder_hid_proj", None)
        if text_projection is not None:
            encoder_hidden_states = text_projection(encoder_hidden_states)
        projected_image_embeds = projection(added_cond_kwargs.get("image_embeds"))
        return (encoder_hidden_states, projected_image_embeds)

    return encoder_hidden_states
If provided, the embeddings will be summed with the samples passed through the `self.time_embedding` layer to obtain the timestep embeddings. attention_mask (`torch.Tensor`, *optional*, defaults to `None`): An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): A tuple of tensors that if specified are added to the residuals of down unet blocks. mid_block_additional_residual: (`torch.Tensor`, *optional*): A tensor that if specified is added to the residual of the middle unet block. down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 
Returns: [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None for dim in sample.shape[-2:]: if dim % default_overall_up_factor != 0: # Forward upsample size to force interpolation output size. forward_upsample_size = True break # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. 
center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time t_emb = self.get_time_embed(sample=sample, timestep=timestep) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) if class_emb is not None: if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb aug_emb = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) if self.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) encoder_hidden_states = self.process_encoder_hidden_states( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) # 2. pre-process sample = self.conv_in(sample) # 2.5 GLIGEN position net if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: cross_attention_kwargs = cross_attention_kwargs.copy() gligen_args = cross_attention_kwargs.pop("gligen") cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. down # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. 
if cross_attention_kwargs is not None: cross_attention_kwargs = cross_attention_kwargs.copy() lora_scale = cross_attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets is_adapter = down_intrablock_additional_residuals is not None # maintain backward compatibility for legacy usage, where # T2I-Adapter and ControlNet both use down_block_additional_residuals arg # but can only use one or the other if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: deprecate( "T2I should not use down_block_additional_residuals", "1.3.0", "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. 
", standard_warn=False, ) down_intrablock_additional_residuals = down_block_additional_residuals is_adapter = True down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_intrablock_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_intrablock_additional_residuals) > 0: sample += down_intrablock_additional_residuals.pop(0) down_block_res_samples += res_samples if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. 
mid if self.mid_block is not None: if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) else: sample = self.mid_block(sample, emb) # To support T2I-Adapter-XL if ( is_adapter and len(down_intrablock_additional_residuals) > 0 and sample.shape == down_intrablock_additional_residuals[0].shape ): sample += down_intrablock_additional_residuals.pop(0) if is_controlnet: sample = sample + mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, ) # 6. 
@flax.struct.dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
    """
    The output of [`FlaxUNet2DConditionModel`].

    Args:
        sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # Final model output; the only field carried by this container.
    sample: jnp.ndarray
@flax_register_to_config
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
    implemented for all models (such as downloading or saving).

    This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
    general usage and behavior.

    Inherent JAX features such as the following are supported:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        sample_size (`int`, *optional*, defaults to 32):
            The size of the input sample.
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4):
            The number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use. `"CrossAttnDownBlock2D"` selects `FlaxCrossAttnDownBlock2D`; any
            other value selects `FlaxDownBlock2D`.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use. `"CrossAttnUpBlock2D"` selects `FlaxCrossAttnUpBlock2D`; any other
            value selects `FlaxUpBlock2D`.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer
            is skipped.
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Forwarded to the cross-attention down/up blocks; a single `bool` is repeated for every block.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads. Because of a historical naming issue this value is what is
            actually passed to the blocks as `num_attention_heads` (see `setup`).
        num_attention_heads (`int` or `Tuple[int]`, *optional*):
            The number of attention heads. Must currently be left as `None`; `setup` raises a `ValueError`
            otherwise.
        cross_attention_dim (`int`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for down, up and bottleneck blocks.
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            Forwarded to the cross-attention blocks and the mid block.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The dtype forwarded to the sub-modules created in `setup`.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            Forwarded to the cross-attention blocks; a single `int` is repeated for every block.
        addition_embed_type (`str`, *optional*):
            Either `None` or `"text_time"`. With `"text_time"`, `__call__` requires `text_embeds` and `time_ids` in
            `added_cond_kwargs`.
        addition_time_embed_dim (`int`, *optional*):
            Embedding size of the additional time projection; required when `addition_embed_type="text_time"`.
        addition_embed_type_num_heads (`int`, *optional*, defaults to 64):
            Not referenced by the methods of this class.
        projection_class_embeddings_input_dim (`int`, *optional*):
            Used by `init_weights` to derive the dummy `text_embeds` / `time_ids` shapes for `"text_time"`.
    """

    sample_size: int = 32
    in_channels: int = 4
    out_channels: int = 4
    down_block_types: Tuple[str, ...] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )
    up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
    only_cross_attention: Union[bool, Tuple[bool]] = False
    block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int, ...]] = 8
    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
    dtype: jnp.dtype = jnp.float32
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1
    addition_embed_type: Optional[str] = None
    addition_time_embed_dim: Optional[int] = None
    addition_embed_type_num_heads: int = 64
    projection_class_embeddings_input_dim: Optional[int] = None

    def init_weights(self, rng: jax.Array) -> FrozenDict:
        """Initialize the module on dummy inputs and return its `"params"` collection.

        Args:
            rng: PRNG key; it is split into one key for `"params"` and one for `"dropout"`.

        Returns:
            The `"params"` entry of the variables returned by `self.init`.
        """
        # init input tensors
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        added_cond_kwargs = None
        if self.addition_embed_type == "text_time":
            # we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner
            # or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim`
            is_refiner = (
                5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
                == self.config.projection_class_embeddings_input_dim
            )
            num_micro_conditions = 5 if is_refiner else 6

            text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
                num_micro_conditions * self.config.addition_time_embed_dim
            )

            time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
            time_ids_dims = time_ids_channels // self.addition_time_embed_dim
            added_cond_kwargs = {
                "text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
                "time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
            }
        return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]

    def setup(self) -> None:
        """Create the sub-modules: input conv, time embeddings, down/mid/up blocks and the output head.

        Raises:
            ValueError: If `num_attention_heads` is set, if `addition_embed_type` is neither `None` nor
                `"text_time"`, if `"text_time"` is used without `addition_time_embed_dim`, or if `mid_block_type`
                is not supported.
        """
        block_out_channels = self.block_out_channels
        time_embed_dim = block_out_channels[0] * 4

        if self.num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = self.num_attention_heads or self.attention_head_dim

        # input
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # time
        self.time_proj = FlaxTimesteps(
            block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

        # broadcast scalar per-model settings to one entry per down block
        only_cross_attention = self.only_cross_attention
        if isinstance(only_cross_attention, bool):
            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

        # transformer layers per block
        transformer_layers_per_block = self.transformer_layers_per_block
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)

        # addition embed types
        if self.addition_embed_type is None:
            self.add_embedding = None
        elif self.addition_embed_type == "text_time":
            if self.addition_time_embed_dim is None:
                raise ValueError(
                    f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
                )
            self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
            self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
        else:
            raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.")

        # down
        down_blocks = []
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # the last block keeps the resolution (no downsampler)
            is_final_block = i == len(block_out_channels) - 1

            if down_block_type == "CrossAttnDownBlock2D":
                down_block = FlaxCrossAttnDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    transformer_layers_per_block=transformer_layers_per_block[i],
                    num_attention_heads=num_attention_heads[i],
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    split_head_dim=self.split_head_dim,
                    dtype=self.dtype,
                )
            else:
                down_block = FlaxDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )

            down_blocks.append(down_block)
        self.down_blocks = down_blocks

        # mid
        if self.config.mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = FlaxUNetMidBlock2DCrossAttn(
                in_channels=block_out_channels[-1],
                dropout=self.dropout,
                num_attention_heads=num_attention_heads[-1],
                transformer_layers_per_block=transformer_layers_per_block[-1],
                use_linear_projection=self.use_linear_projection,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
        elif self.config.mid_block_type is None:
            self.mid_block = None
        else:
            raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}")

        # up
        # the per-block settings are walked in reverse order relative to the down path
        up_blocks = []
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        only_cross_attention = list(reversed(only_cross_attention))
        output_channel = reversed_block_out_channels[0]
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
        for i, up_block_type in enumerate(self.up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            if up_block_type == "CrossAttnUpBlock2D":
                up_block = FlaxCrossAttnUpBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    num_layers=self.layers_per_block + 1,
                    transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                    num_attention_heads=reversed_num_attention_heads[i],
                    add_upsample=not is_final_block,
                    dropout=self.dropout,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    split_head_dim=self.split_head_dim,
                    dtype=self.dtype,
                )
            else:
                up_block = FlaxUpBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    num_layers=self.layers_per_block + 1,
                    add_upsample=not is_final_block,
                    dropout=self.dropout,
                    dtype=self.dtype,
                )

            up_blocks.append(up_block)
            prev_output_channel = output_channel
        self.up_blocks = up_blocks

        # out
        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv_out = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(
        self,
        sample: jnp.ndarray,
        timesteps: Union[jnp.ndarray, float, int],
        encoder_hidden_states: jnp.ndarray,
        added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
        down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None,
        mid_block_additional_residual: Optional[jnp.ndarray] = None,
        return_dict: bool = True,
        train: bool = False,
    ) -> Union[FlaxUNet2DConditionOutput, Tuple[jnp.ndarray]]:
        r"""
        Args:
            sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
            timesteps (`jnp.ndarray` or `float` or `int`): timesteps
            encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks. Must contain `text_embeds` and `time_ids` when
                `addition_embed_type` is `"text_time"`.
            down_block_additional_residuals: (`tuple` of `jnp.ndarray`, *optional*):
                A tuple of tensors that if specified are added to the residuals of down unet blocks.
            mid_block_additional_residual: (`jnp.ndarray`, *optional*):
                A tensor that if specified is added to the residual of the middle unet block.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of
                a plain tuple.
            train (`bool`, *optional*, defaults to `False`):
                Use deterministic functions and disable dropout when not training.

        Returns:
            [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
            [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: If `addition_embed_type` is `"text_time"` and `added_cond_kwargs`, `text_embeds` or
                `time_ids` is missing.
        """
        # 1. time
        # Python scalars become a 1-element int32 array; 0-d arrays are cast to float32 and given a batch axis.
        if not isinstance(timesteps, jnp.ndarray):
            timesteps = jnp.array([timesteps], dtype=jnp.int32)
        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
            timesteps = timesteps.astype(dtype=jnp.float32)
            timesteps = jnp.expand_dims(timesteps, 0)

        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

        # additional embeddings
        aug_emb = None
        if self.addition_embed_type == "text_time":
            if added_cond_kwargs is None:
                raise ValueError(
                    f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if text_embeds is None:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            if time_ids is None:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            # compute time embeds
            time_embeds = self.add_time_proj(jnp.ravel(time_ids))  # (1, 6) => (6,) => (6, 256)
            time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
            add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
            aug_emb = self.add_embedding(add_embeds)

        t_emb = t_emb + aug_emb if aug_emb is not None else t_emb

        # 2. pre-process
        # move channels to the last axis before the first convolution
        sample = jnp.transpose(sample, (0, 2, 3, 1))
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for down_block in self.down_blocks:
            if isinstance(down_block, FlaxCrossAttnDownBlock2D):
                sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
            else:
                sample, res_samples = down_block(sample, t_emb, deterministic=not train)
            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample += down_block_additional_residual
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

        if mid_block_additional_residual is not None:
            sample += mid_block_additional_residual

        # 5. up
        for up_block in self.up_blocks:
            # each up block takes the last `layers_per_block + 1` skip connections off the stack
            res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
            down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
            if isinstance(up_block, FlaxCrossAttnUpBlock2D):
                sample = up_block(
                    sample,
                    temb=t_emb,
                    encoder_hidden_states=encoder_hidden_states,
                    res_hidden_states_tuple=res_samples,
                    deterministic=not train,
                )
            else:
                sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)
        # undo the transpose from step 2 so the output layout matches the input layout
        sample = jnp.transpose(sample, (0, 3, 1, 2))

        if not return_dict:
            return (sample,)

        return FlaxUNet2DConditionOutput(sample=sample)
def _moved_to_motion_model_message(class_name: str) -> str:
    """Build the deprecation text for a block class that now lives in `unet_motion_model`."""
    return (
        f"Importing `{class_name}` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be "
        f"removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import "
        f"{class_name}` instead."
    )


# Each class below is a thin backward-compatibility shim: it subclasses the class of the same name imported
# from `unet_motion_model`, emits a deprecation notice on construction and otherwise defers to the parent.


class DownBlockMotion(DownBlockMotion):
    """Deprecated alias of `unet_motion_model.DownBlockMotion`."""

    def __init__(self, *args, **kwargs):
        deprecate("DownBlockMotion", "1.0.0", _moved_to_motion_model_message("DownBlockMotion"))
        super().__init__(*args, **kwargs)


class CrossAttnDownBlockMotion(CrossAttnDownBlockMotion):
    """Deprecated alias of `unet_motion_model.CrossAttnDownBlockMotion`."""

    def __init__(self, *args, **kwargs):
        deprecate(
            "CrossAttnDownBlockMotion", "1.0.0", _moved_to_motion_model_message("CrossAttnDownBlockMotion")
        )
        super().__init__(*args, **kwargs)


class UpBlockMotion(UpBlockMotion):
    """Deprecated alias of `unet_motion_model.UpBlockMotion`."""

    def __init__(self, *args, **kwargs):
        deprecate("UpBlockMotion", "1.0.0", _moved_to_motion_model_message("UpBlockMotion"))
        super().__init__(*args, **kwargs)


class CrossAttnUpBlockMotion(CrossAttnUpBlockMotion):
    """Deprecated alias of `unet_motion_model.CrossAttnUpBlockMotion`."""

    def __init__(self, *args, **kwargs):
        deprecate("CrossAttnUpBlockMotion", "1.0.0", _moved_to_motion_model_message("CrossAttnUpBlockMotion"))
        super().__init__(*args, **kwargs)


class UNetMidBlockCrossAttnMotion(UNetMidBlockCrossAttnMotion):
    """Deprecated alias of `unet_motion_model.UNetMidBlockCrossAttnMotion`."""

    def __init__(self, *args, **kwargs):
        deprecate(
            "UNetMidBlockCrossAttnMotion",
            "1.0.0",
            _moved_to_motion_model_message("UNetMidBlockCrossAttnMotion"),
        )
        super().__init__(*args, **kwargs)
def get_down_block(
    down_block_type: str,
    num_layers: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    add_downsample: bool,
    resnet_eps: float,
    resnet_act_fn: str,
    num_attention_heads: int,
    resnet_groups: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    downsample_padding: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = True,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    temporal_num_attention_heads: int = 8,
    temporal_max_seq_length: int = 32,
    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    dropout: float = 0.0,
) -> Union[
    "DownBlock3D",
    "CrossAttnDownBlock3D",
    "DownBlockSpatioTemporal",
    "CrossAttnDownBlockSpatioTemporal",
]:
    """Build the down block selected by `down_block_type`.

    Supported names are `"DownBlock3D"`, `"CrossAttnDownBlock3D"`, `"DownBlockSpatioTemporal"` and
    `"CrossAttnDownBlockSpatioTemporal"`. Only the arguments a given block accepts are forwarded to it; the
    `temporal_*` arguments are accepted for signature compatibility and are not forwarded to any block.

    Raises:
        ValueError: If `down_block_type` is not one of the supported names, or if a cross-attention block is
            requested while `cross_attention_dim` is `None`.
    """
    # Arguments understood by every supported block.
    shared_kwargs = {
        "num_layers": num_layers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "temb_channels": temb_channels,
        "add_downsample": add_downsample,
    }
    # ResNet settings that only the plain 3D blocks take.
    resnet_kwargs = {
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "resnet_groups": resnet_groups,
        "downsample_padding": downsample_padding,
        "resnet_time_scale_shift": resnet_time_scale_shift,
        "dropout": dropout,
    }

    if down_block_type == "DownBlock3D":
        return DownBlock3D(**shared_kwargs, **resnet_kwargs)

    if down_block_type == "DownBlockSpatioTemporal":
        # added for SDV
        return DownBlockSpatioTemporal(**shared_kwargs)

    if down_block_type not in ("CrossAttnDownBlock3D", "CrossAttnDownBlockSpatioTemporal"):
        raise ValueError(f"{down_block_type} does not exist.")

    # Only cross-attention variants remain; both need the conditioning width.
    if cross_attention_dim is None:
        raise ValueError(f"cross_attention_dim must be specified for {down_block_type}")

    attention_kwargs = {
        "cross_attention_dim": cross_attention_dim,
        "num_attention_heads": num_attention_heads,
    }

    if down_block_type == "CrossAttnDownBlock3D":
        return CrossAttnDownBlock3D(
            **shared_kwargs,
            **resnet_kwargs,
            **attention_kwargs,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
        )

    # added for SDV
    return CrossAttnDownBlockSpatioTemporal(
        **shared_kwargs,
        **attention_kwargs,
        transformer_layers_per_block=transformer_layers_per_block,
    )
upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, dropout=dropout, ) elif down_block_type == "DownBlockSpatioTemporal": # added for SDV return DownBlockSpatioTemporal( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, ) elif down_block_type == "CrossAttnDownBlockSpatioTemporal": # added for SDV if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal") return CrossAttnDownBlockSpatioTemporal( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, num_layers=num_layers, transformer_layers_per_block=transformer_layers_per_block, add_downsample=add_downsample, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, ) raise ValueError(f"{down_block_type} does not exist.") def get_up_block( up_block_type: str, num_layers: int, in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, add_upsample: bool, resnet_eps: float, resnet_act_fn: str, num_attention_heads: int, resolution_idx: Optional[int] = None, resnet_groups: Optional[int] = None, cross_attention_dim: Optional[int] = None, dual_cross_attention: bool = False, use_linear_projection: bool = True, only_cross_attention: bool = False, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", temporal_num_attention_heads: int = 8, temporal_cross_attention_dim: Optional[int] = None, temporal_max_seq_length: int = 32, transformer_layers_per_block: Union[int, Tuple[int]] = 1, temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, dropout: float = 0.0, ) -> Union[ "UpBlock3D", "CrossAttnUpBlock3D", "UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", ]: if up_block_type == "UpBlock3D": return UpBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, 
class UNetMidBlock3DCrossAttn(nn.Module):
    """Middle UNet block interleaving ResNet, temporal-conv and attention layers.

    The block always starts with one ``ResnetBlock2D`` + ``TemporalConvLayer``
    pair. Each of the ``num_layers`` following layers is
    ``Transformer2DModel`` -> ``TransformerTemporalModel`` -> ``ResnetBlock2D``
    -> ``TemporalConvLayer``. All layers keep ``in_channels`` channels.
    ``dual_cross_attention`` is accepted but not used by this class.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = True,
        upcast_attention: bool = False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # Fall back to a channel-derived group count when none is given.
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        def build_resnet():
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        def build_temp_conv():
            # The temporal conv uses a fixed dropout of 0.1, independent of `dropout`.
            return TemporalConvLayer(
                in_channels,
                in_channels,
                dropout=0.1,
                norm_num_groups=resnet_groups,
            )

        # There is always at least one resnet / temporal-conv pair.
        resnet_blocks = [build_resnet()]
        temporal_convs = [build_temp_conv()]
        spatial_attns = []
        temporal_attns = []

        # Modules are created in the order attention, temporal attention,
        # resnet, temporal conv for every extra layer.
        for _ in range(num_layers):
            spatial_attns.append(
                Transformer2DModel(
                    in_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                )
            )
            temporal_attns.append(
                TransformerTemporalModel(
                    in_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )
            resnet_blocks.append(build_resnet())
            temporal_convs.append(build_temp_conv())

        self.resnets = nn.ModuleList(resnet_blocks)
        self.temp_convs = nn.ModuleList(temporal_convs)
        self.attentions = nn.ModuleList(spatial_attns)
        self.temp_attentions = nn.ModuleList(temporal_attns)

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        num_frames: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """Run the leading resnet/temporal-conv pair, then every attention layer.

        ``attention_mask`` is accepted but not used.
        """
        hidden_states = self.resnets[0](hidden_states, temb)
        hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)

        # The first resnet / temporal conv were consumed above; pair the rest
        # with the attention modules.
        remaining_layers = zip(
            self.attentions,
            self.temp_attentions,
            self.resnets[1:],
            self.temp_convs[1:],
        )
        for spatial_attn, temporal_attn, resnet, temp_conv in remaining_layers:
            spatial_out = spatial_attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )
            hidden_states = spatial_out[0]

            temporal_out = temporal_attn(
                hidden_states,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )
            hidden_states = temporal_out[0]

            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        return hidden_states
= num_attention_heads for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) temp_convs.append( TemporalConvLayer( out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, ) ) attentions.append( Transformer2DModel( out_channels // num_attention_heads, num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, ) ) temp_attentions.append( TransformerTemporalModel( out_channels // num_attention_heads, num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) self.temp_attentions = nn.ModuleList(temp_attentions) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op", ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, num_frames: int = 1, cross_attention_kwargs: Dict[str, Any] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: # TODO(Patrick, William) - attention mask is not used output_states = () for resnet, temp_conv, attn, temp_attn in zip( self.resnets, self.temp_convs, self.attentions, 
class DownBlock3D(nn.Module):
    """Down-sampling stage built from ResNet + temporal-convolution pairs.

    ``num_layers`` pairs of ``ResnetBlock2D`` followed by ``TemporalConvLayer``
    are created. When ``add_downsample`` is true a single ``Downsample2D`` is
    stored in ``self.downsamplers``; otherwise that attribute is ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_padding: int = 1,
    ):
        super().__init__()

        resnet_stack = []
        temporal_stack = []
        for layer_idx in range(num_layers):
            # Only the first resnet sees the block's input width.
            layer_in = in_channels if layer_idx == 0 else out_channels
            resnet_stack.append(
                ResnetBlock2D(
                    in_channels=layer_in,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            # The temporal conv uses a fixed dropout of 0.1, independent of `dropout`.
            temporal_stack.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                    norm_num_groups=resnet_groups,
                )
            )

        self.resnets = nn.ModuleList(resnet_stack)
        self.temp_convs = nn.ModuleList(temporal_stack)

        self.downsamplers = (
            nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
            if add_downsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        num_frames: int = 1,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Return ``(hidden_states, output_states)``.

        ``output_states`` is a tuple holding the activation after every
        resnet/temporal-conv pair, followed by the down-sampled activation
        when down-samplers are present.
        """
        collected = []

        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            # One entry is recorded after all down-samplers have run.
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, ) ) temp_attentions.append( TransformerTemporalModel( out_channels // num_attention_heads, num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) self.temp_attentions = nn.ModuleList(temp_attentions) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False self.resolution_idx = resolution_idx def forward( self, hidden_states: torch.Tensor, res_hidden_states_tuple: Tuple[torch.Tensor, ...], temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, num_frames: int = 1, cross_attention_kwargs: Dict[str, Any] = None, ) -> torch.Tensor: is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) and getattr(self, "b1", None) and getattr(self, "b2", None) ) # TODO(Patrick, William) - attention mask is not used for resnet, temp_conv, attn, temp_attn in zip( self.resnets, self.temp_convs, self.attentions, self.temp_attentions ): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] # FreeU: Only operate on the first two stages if is_freeu_enabled: hidden_states, res_hidden_states = apply_freeu( self.resolution_idx, hidden_states, res_hidden_states, s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2, ) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = 
class UpBlock3D(nn.Module):
    """Up-sampling stage built from ResNet + temporal-convolution pairs.

    Every layer concatenates the running activation with one skip tensor
    (taken from the end of ``res_hidden_states_tuple``) along ``dim=1`` before
    its resnet, so each resnet's input width is the sum of both widths.
    When ``add_upsample`` is true a single ``Upsample2D`` is stored in
    ``self.upsamplers``; otherwise that attribute is ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        resolution_idx: Optional[int] = None,
    ):
        super().__init__()

        resnet_stack = []
        temporal_stack = []
        final_idx = num_layers - 1
        for layer_idx in range(num_layers):
            # Width of the skip tensor and of the running activation for this layer.
            skip_width = in_channels if layer_idx == final_idx else out_channels
            main_width = prev_output_channel if layer_idx == 0 else out_channels

            resnet_stack.append(
                ResnetBlock2D(
                    in_channels=main_width + skip_width,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            # The temporal conv uses a fixed dropout of 0.1, independent of `dropout`.
            temporal_stack.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                    norm_num_groups=resnet_groups,
                )
            )

        self.resnets = nn.ModuleList(resnet_stack)
        self.temp_convs = nn.ModuleList(temporal_stack)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        upsample_size: Optional[int] = None,
        num_frames: int = 1,
    ) -> torch.Tensor:
        """Fuse skip tensors layer by layer, then optionally up-sample."""
        # FreeU is active only when all four attributes exist and are truthy.
        is_freeu_enabled = all(getattr(self, attr, None) for attr in ("s1", "s2", "b1", "b2"))

        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # Take the last skip tensor and shrink the remaining tuple.
            skip_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            # FreeU: Only operate on the first two stages
            if is_freeu_enabled:
                hidden_states, skip_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    skip_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )

            merged = torch.cat([hidden_states, skip_states], dim=1)
            hidden_states = resnet(merged, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
class UpBlockTemporalDecoder(nn.Module):
    """Decoder up-block: a stack of ``SpatioTemporalResBlock`` plus optional up-sampling.

    The resnets are built without a time embedding (``temb_channels=None``).
    When ``add_upsample`` is true a single ``Upsample2D`` is stored in
    ``self.upsamplers``; otherwise that attribute is ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        add_upsample: bool = True,
    ):
        super().__init__()

        # Only the first resnet sees the block's input width.
        layer_widths = [in_channels if idx == 0 else out_channels for idx in range(num_layers)]
        self.resnets = nn.ModuleList(
            [
                SpatioTemporalResBlock(
                    in_channels=width,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=1e-6,
                    temporal_eps=1e-5,
                    merge_factor=0.0,
                    merge_strategy="learned",
                    switch_spatial_to_temporal_mix=True,
                )
                for width in layer_widths
            ]
        )

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        image_only_indicator: torch.Tensor,
    ) -> torch.Tensor:
        """Apply every resnet in order, then every up-sampler (if any)."""
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, image_only_indicator=image_only_indicator)

        if self.upsamplers is None:
            return hidden_states

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)
        return hidden_states
class DownBlockSpatioTemporal(nn.Module):
    """Down-sampling stage made of ``SpatioTemporalResBlock`` layers.

    When ``add_downsample`` is true a single ``Downsample2D`` is stored in
    ``self.downsamplers``; otherwise that attribute is ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_layers: int = 1,
        add_downsample: bool = True,
    ):
        super().__init__()

        # Only the first resnet sees the block's input width.
        layer_widths = [in_channels if idx == 0 else out_channels for idx in range(num_layers)]
        self.resnets = nn.ModuleList(
            [
                SpatioTemporalResBlock(
                    in_channels=width,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=1e-5,
                )
                for width in layer_widths
            ]
        )

        self.downsamplers = (
            nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        name="op",
                    )
                ]
            )
            if add_downsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Return ``(hidden_states, output_states)``.

        ``output_states`` holds the activation after every resnet, followed
        by the down-sampled activation when down-samplers are present. In
        training mode with ``gradient_checkpointing`` enabled each resnet is
        run through ``torch.utils.checkpoint.checkpoint``.
        """

        def as_positional_call(module):
            # checkpoint() hands its inputs over positionally.
            def run(*inputs):
                return module(*inputs)

            return run

        collected = []
        for resnet in self.resnets:
            if not (self.training and self.gradient_checkpointing):
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )
            else:
                # `use_reentrant` is only passed on torch >= 1.11.0.
                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    as_positional_call(resnet),
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )
            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            # One entry is recorded after all down-samplers have run.
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( SpatioTemporalResBlock( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=1e-6, ) ) attentions.append( TransformerSpatioTemporalModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=1, name="op", ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: output_states = () blocks = list(zip(self.resnets, self.attentions)) for resnet, attn in blocks: if self.training and self.gradient_checkpointing: # TODO def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, image_only_indicator, **ckpt_kwargs, ) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, return_dict=False, )[0] else: hidden_states = resnet( hidden_states, temb, image_only_indicator=image_only_indicator, ) hidden_states = attn( hidden_states, 
class UpBlockSpatioTemporal(nn.Module):
    """Up-sampling stage made of ``SpatioTemporalResBlock`` layers.

    Every layer concatenates the running activation with one skip tensor
    (taken from the end of ``res_hidden_states_tuple``) along ``dim=1`` before
    its resnet. When ``add_upsample`` is true a single ``Upsample2D`` is
    stored in ``self.upsamplers``; otherwise that attribute is ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        add_upsample: bool = True,
    ):
        super().__init__()

        resnet_stack = []
        final_idx = num_layers - 1
        for layer_idx in range(num_layers):
            # Width of the skip tensor and of the running activation for this layer.
            skip_width = in_channels if layer_idx == final_idx else out_channels
            main_width = prev_output_channel if layer_idx == 0 else out_channels

            resnet_stack.append(
                SpatioTemporalResBlock(
                    in_channels=main_width + skip_width,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                )
            )

        self.resnets = nn.ModuleList(resnet_stack)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Fuse skip tensors layer by layer, then optionally up-sample.

        In training mode with ``gradient_checkpointing`` enabled each resnet
        is run through ``torch.utils.checkpoint.checkpoint``.
        """

        def as_positional_call(module):
            # checkpoint() hands its inputs over positionally.
            def run(*inputs):
                return module(*inputs)

            return run

        for resnet in self.resnets:
            # Take the last skip tensor and shrink the remaining tuple.
            skip_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            if not (self.training and self.gradient_checkpointing):
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )
            else:
                # `use_reentrant` is only passed on torch >= 1.11.0.
                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    as_positional_call(resnet),
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
Tuple[torch.Tensor, ...], temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None, ) -> torch.Tensor: for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: # TODO def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, image_only_indicator, **ckpt_kwargs, ) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, return_dict=False, )[0] else: hidden_states = resnet( hidden_states, temb, image_only_indicator=image_only_indicator, ) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, return_dict=False, )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) return hidden_states ================================================ FILE: src/diffusers/models/unets/unet_3d_condition.py ================================================ # Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. # Copyright 2024 The ModelScope Team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """
    The output of [`UNet3DConditionModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.Tensor


class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers is skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
        attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*): The number of attention heads.
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        up_block_types: Tuple[str, ...] = (
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1024,
        attention_head_dim: Union[int, Tuple[int]] = 64,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        time_cond_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise NotImplementedError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_kernel = 3
        conv_out_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], True, 0)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            cond_proj_dim=time_cond_proj_dim,
        )

        self.transformer_in = TransformerTemporalModel(
            num_attention_heads=8,
            attention_head_dim=attention_head_dim,
            in_channels=block_out_channels[0],
            num_layers=1,
            norm_num_groups=norm_num_groups,
        )

        # blocks
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # A single int applies to every block; expand it so it can be indexed per block below.
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=False,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock3DCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            dual_cross_attention=False,
        )

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=False,
                resolution_idx=i,
            )
            self.up_blocks.append(up_block)

        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
            self.conv_act = get_activation("silu")
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    # Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors. The dict passed
                in is not modified.

        Raises:
            ValueError: If `processor` is a dict whose length differs from the number of attention layers.
            KeyError: If `processor` is a dict that lacks the `"<module path>.processor"` key of some layer.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        # Entries are popped while walking the module tree. Consume a shallow copy so the caller's
        # dict (e.g. `self.original_attn_processors` during unfusing) is still intact afterwards.
        if isinstance(processor, dict):
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Enables [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) on every submodule that
        exposes `set_chunk_feed_forward`.

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).

        Raises:
            ValueError: If `dim` is neither 0 nor 1.
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    def disable_forward_chunking(self):
        """
        Disables feed forward chunking by calling `set_chunk_feed_forward(chunk_size=None, dim=0)` on every
        submodule that exposes it.
        """

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        """Sets `module.gradient_checkpointing = value` when `module` is one of the 3D down/up block types."""
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
    def enable_freeu(self, s1, s2, b1, b2):
        r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

        The suffixes after the scaling factors represent the stage blocks where they are being applied.

        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
        are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

        Args:
            s1 (`float`):
                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
                mitigate the "oversmoothing effect" in the enhanced denoising process.
            s2 (`float`):
                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
                mitigate the "oversmoothing effect" in the enhanced denoising process.
            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
        """
        for i, upsample_block in enumerate(self.up_blocks):
            setattr(upsample_block, "s1", s1)
            setattr(upsample_block, "s2", s2)
            setattr(upsample_block, "b1", b1)
            setattr(upsample_block, "b2", b2)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
    def disable_freeu(self):
        """Disables the FreeU mechanism."""
        freeu_keys = {"s1", "s2", "b1", "b2"}
        for i, upsample_block in enumerate(self.up_blocks):
            for k in freeu_keys:
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    setattr(upsample_block, k, None)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    # Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        Restores the attention processors saved by `fuse_qkv_projections()`. Does nothing when no processors were
        saved, including when `fuse_qkv_projections()` was never called on this instance.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        # The attribute is only created by `fuse_qkv_projections()`, so look it up defensively.
        original_attn_processors = getattr(self, "original_attn_processors", None)
        if original_attn_processors is not None:
            self.set_attn_processor(original_attn_processors)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
        r"""
        The [`UNet3DConditionModel`] forward method.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width)`.
            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Accepted for interface compatibility; this method does not read it.
            timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
                Conditional embeddings for timestep. It is passed to `self.time_embedding` together with the
                projected timesteps.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
                A tuple of tensors that if specified are added to the residuals of down unet blocks.
            mid_block_additional_residual: (`torch.Tensor`, *optional*):
                A tensor that if specified is added to the residual of the middle unet block.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned,
                otherwise a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        num_frames = sample.shape[2]
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        # Frames are folded into the batch axis below, so repeat the per-sample conditioning once per frame.
        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

        # 2. pre-process
        # (batch, channels, frames, h, w) -> (batch * frames, channels, h, w)
        sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
        sample = self.conv_in(sample)

        sample = self.transformer_in(
            sample,
            num_frames=num_frames,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    num_frames=num_frames,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # Each up block consumes one skip connection per resnet, taken from the end of the stack.
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    num_frames=num_frames,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    num_frames=num_frames,
                )

        # 6. post-process
        # `conv_norm_out` and `conv_act` are both None when the model was built with `norm_num_groups=None`.
        if self.conv_norm_out is not None:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)

        sample = self.conv_out(sample)

        # reshape to (batch, channels, num_frames, height, width)
        sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)
class I2VGenXLTransformerTemporalEncoder(nn.Module):
    r"""
    A single encoder layer: layer-normalized self-attention followed by a feed-forward network, each wrapped in a
    residual connection.

    Parameters:
        dim (`int`): Feature size of the input; used for the layer norm, the attention query dim and the
            feed-forward dim.
        num_attention_heads (`int`): Passed to `Attention` as `heads`.
        attention_head_dim (`int`): Passed to `Attention` as `dim_head`.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function passed to `FeedForward`.
        upcast_attention (`bool`, *optional*, defaults to `False`): Passed through to `Attention`.
        ff_inner_dim (`int`, *optional*): Passed to `FeedForward` as `inner_dim`.
        dropout (`float`, *optional*, defaults to 0.0): Dropout value passed to both `Attention` and
            `FeedForward`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        activation_fn: str = "geglu",
        upcast_attention: bool = False,
        ff_inner_dim: Optional[int] = None,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=False,
            upcast_attention=upcast_attention,
            out_bias=True,
        )
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=False,
            inner_dim=ff_inner_dim,
            bias=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply the self-attention and feed-forward sub-layers to `hidden_states`.

        Args:
            hidden_states (`torch.Tensor`): Input features.

        Returns:
            `torch.Tensor`: The transformed features.
        """
        # Self-attention on the normalized input (no encoder states are given), plus residual.
        norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
        hidden_states = attn_output + hidden_states
        # `squeeze(1)` only drops axis 1 when it has size 1; other 4D inputs pass through unchanged.
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # The feed-forward is applied to the un-normalized residual stream, plus residual.
        ff_output = self.ff(hidden_states)
        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. attention_head_dim (`int`, *optional*, defaults to 64): Attention head dim. num_attention_heads (`int`, *optional*): The number of attention heads. """ _supports_gradient_checkpointing = False @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, down_block_types: Tuple[str, ...] = ( "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ), up_block_types: Tuple[str, ...] = ( "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", ), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), layers_per_block: int = 2, norm_num_groups: Optional[int] = 32, cross_attention_dim: int = 1024, attention_head_dim: Union[int, Tuple[int]] = 64, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, ): super().__init__() # When we first integrated the UNet into the library, we didn't have `attention_head_dim`. As a consequence # of that, we used `num_attention_heads` for arguments that actually denote attention head dimension. This # is why we ignore `num_attention_heads` and calculate it from `attention_head_dims` below. # This is still an incorrect way of calculating `num_attention_heads` but we need to stick to it # without running proper depcrecation cycles for the {down,mid,up} blocks which are a # part of the public API. num_attention_heads = attention_head_dim # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 
) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) # input self.conv_in = nn.Conv2d(in_channels + in_channels, block_out_channels[0], kernel_size=3, padding=1) self.transformer_in = TransformerTemporalModel( num_attention_heads=8, attention_head_dim=num_attention_heads, in_channels=block_out_channels[0], num_layers=1, norm_num_groups=norm_num_groups, ) # image embedding self.image_latents_proj_in = nn.Sequential( nn.Conv2d(4, in_channels * 4, 3, padding=1), nn.SiLU(), nn.Conv2d(in_channels * 4, in_channels * 4, 3, stride=1, padding=1), nn.SiLU(), nn.Conv2d(in_channels * 4, in_channels, 3, stride=1, padding=1), ) self.image_latents_temporal_encoder = I2VGenXLTransformerTemporalEncoder( dim=in_channels, num_attention_heads=2, ff_inner_dim=in_channels * 4, attention_head_dim=in_channels, activation_fn="gelu", ) self.image_latents_context_embedding = nn.Sequential( nn.Conv2d(4, in_channels * 8, 3, padding=1), nn.SiLU(), nn.AdaptiveAvgPool2d((32, 32)), nn.Conv2d(in_channels * 8, in_channels * 16, 3, stride=2, padding=1), nn.SiLU(), nn.Conv2d(in_channels * 16, cross_attention_dim, 3, stride=2, padding=1), ) # other embeddings -- time, context, fps, etc. 
time_embed_dim = block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], True, 0) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn="silu") self.context_embedding = nn.Sequential( nn.Linear(cross_attention_dim, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, cross_attention_dim * in_channels), ) self.fps_embedding = nn.Sequential( nn.Linear(timestep_input_dim, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) ) # blocks self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=1e-05, resnet_act_fn="silu", resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[i], downsample_padding=1, dual_cross_attention=False, ) self.down_blocks.append(down_block) # mid self.mid_block = UNetMidBlock3DCrossAttn( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=1e-05, resnet_act_fn="silu", output_scale_factor=1, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, ) # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) output_channel = reversed_block_out_channels[0] for i, up_block_type in 
enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=layers_per_block + 1, in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=time_embed_dim, add_upsample=add_upsample, resnet_eps=1e-05, resnet_act_fn="silu", resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=False, resolution_idx=i, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-05) self.conv_act = get_activation("silu") self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. 
""" # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ Sets the attention processor to use [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). Parameters: chunk_size (`int`, *optional*): The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually over each tensor of dim=`dim`. dim (`int`, *optional*, defaults to `0`): The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) or dim=1 (sequence length). 
""" if dim not in [0, 1]: raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") # By default chunk size is 1 chunk_size = chunk_size or 1 def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) for child in module.children(): fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): fn_recursive_feed_forward(module, chunk_size, dim) # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking def disable_forward_chunking(self): def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) for child in module.children(): fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): fn_recursive_feed_forward(module, None, 0) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. 
""" if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnAddedKVProcessor() elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnProcessor() else: raise ValueError( f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) self.set_attn_processor(processor) # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel._set_gradient_checkpointing def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): module.gradient_checkpointing = value # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu def enable_freeu(self, s1, s2, b1, b2): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. The suffixes after the scaling factors represent the stage blocks where they are being applied. Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. Args: s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process. s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process. b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
""" for i, upsample_block in enumerate(self.up_blocks): setattr(upsample_block, "s1", s1) setattr(upsample_block, "s2", s2) setattr(upsample_block, "b1", b1) setattr(upsample_block, "b2", b2) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu def disable_freeu(self): """Disables the FreeU mechanism.""" freeu_keys = {"s1", "s2", "b1", "b2"} for i, upsample_block in enumerate(self.up_blocks): for k in freeu_keys: if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: setattr(upsample_block, k, None) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. For cross-attention modules, key and value projection matrices are fused. This API is 🧪 experimental. """ self.original_attn_processors = None for _, attn_processor in self.attn_processors.items(): if "Added" in str(attn_processor.__class__.__name__): raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") self.original_attn_processors = self.attn_processors for module in self.modules(): if isinstance(module, Attention): module.fuse_projections(fuse=True) self.set_attn_processor(FusedAttnProcessor2_0()) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections def unfuse_qkv_projections(self): """Disables the fused QKV projection if enabled. This API is 🧪 experimental. 
""" if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) def forward( self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], fps: torch.Tensor, image_latents: torch.Tensor, image_embeddings: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]: r""" The [`I2VGenXLUNet`] forward method. Args: sample (`torch.Tensor`): The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`. timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. fps (`torch.Tensor`): Frames per second for the video being generated. Used as a "micro-condition". image_latents (`torch.Tensor`): Image encodings from the VAE. image_embeddings (`torch.Tensor`): Projection embeddings of the conditioning image computed with a vision encoder. encoder_hidden_states (`torch.Tensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain tuple. Returns: [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. 
""" batch_size, channels, num_frames, height, width = sample.shape # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timesteps, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) t_emb = self.time_embedding(t_emb, timestep_cond) # 2. FPS # broadcast to batch dimension in a way that's compatible with ONNX/Core ML fps = fps.expand(fps.shape[0]) fps_emb = self.fps_embedding(self.time_proj(fps).to(dtype=self.dtype)) # 3. time + FPS embeddings. 
emb = t_emb + fps_emb emb = emb.repeat_interleave(repeats=num_frames, dim=0) # 4. context embeddings. # The context embeddings consist of both text embeddings from the input prompt # AND the image embeddings from the input image. For images, both VAE encodings # and the CLIP image embeddings are incorporated. # So the final `context_embeddings` becomes the query for cross-attention. context_emb = sample.new_zeros(batch_size, 0, self.config.cross_attention_dim) context_emb = torch.cat([context_emb, encoder_hidden_states], dim=1) image_latents_for_context_embds = image_latents[:, :, :1, :] image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape( image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2], image_latents_for_context_embds.shape[1], image_latents_for_context_embds.shape[3], image_latents_for_context_embds.shape[4], ) image_latents_context_embs = self.image_latents_context_embedding(image_latents_context_embs) _batch_size, _channels, _height, _width = image_latents_context_embs.shape image_latents_context_embs = image_latents_context_embs.permute(0, 2, 3, 1).reshape( _batch_size, _height * _width, _channels ) context_emb = torch.cat([context_emb, image_latents_context_embs], dim=1) image_emb = self.context_embedding(image_embeddings) image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim) context_emb = torch.cat([context_emb, image_emb], dim=1) context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0) image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape( image_latents.shape[0] * image_latents.shape[2], image_latents.shape[1], image_latents.shape[3], image_latents.shape[4], ) image_latents = self.image_latents_proj_in(image_latents) image_latents = ( image_latents[None, :] .reshape(batch_size, num_frames, channels, height, width) .permute(0, 3, 4, 1, 2) .reshape(batch_size * height * width, num_frames, channels) ) image_latents = 
self.image_latents_temporal_encoder(image_latents) image_latents = image_latents.reshape(batch_size, height, width, num_frames, channels).permute(0, 4, 3, 1, 2) # 5. pre-process sample = torch.cat([sample, image_latents], dim=1) sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) sample = self.conv_in(sample) sample = self.transformer_in( sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # 6. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=context_emb, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) down_block_res_samples += res_samples # 7. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=context_emb, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, ) # 8. 
up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=context_emb, upsample_size=upsample_size, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, num_frames=num_frames, ) # 9. post-process sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) # reshape to (batch, channel, framerate, width, height) sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) if not return_dict: return (sample,) return UNet3DConditionOutput(sample=sample) ================================================ FILE: src/diffusers/models/unets/unet_kandinsky3.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Dict, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput, logging from ..attention_processor import Attention, AttentionProcessor, AttnProcessor from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class Kandinsky3UNetOutput(BaseOutput): sample: torch.Tensor = None class Kandinsky3EncoderProj(nn.Module): def __init__(self, encoder_hid_dim, cross_attention_dim): super().__init__() self.projection_linear = nn.Linear(encoder_hid_dim, cross_attention_dim, bias=False) self.projection_norm = nn.LayerNorm(cross_attention_dim) def forward(self, x): x = self.projection_linear(x) x = self.projection_norm(x) return x class Kandinsky3UNet(ModelMixin, ConfigMixin): @register_to_config def __init__( self, in_channels: int = 4, time_embedding_dim: int = 1536, groups: int = 32, attention_head_dim: int = 64, layers_per_block: Union[int, Tuple[int]] = 3, block_out_channels: Tuple[int] = (384, 768, 1536, 3072), cross_attention_dim: Union[int, Tuple[int]] = 4096, encoder_hid_dim: int = 4096, ): super().__init__() # TODO(Yiyi): Give better name and put into config for the following 4 parameters expansion_ratio = 4 compression_ratio = 2 add_cross_attention = (False, True, True, True) add_self_attention = (False, True, True, True) out_channels = in_channels init_channels = block_out_channels[0] // 2 self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1) self.time_embedding = TimestepEmbedding( init_channels, time_embedding_dim, ) self.add_time_condition = Kandinsky3AttentionPooling( time_embedding_dim, cross_attention_dim, attention_head_dim ) self.conv_in = 
nn.Conv2d(in_channels, init_channels, kernel_size=3, padding=1) self.encoder_hid_proj = Kandinsky3EncoderProj(encoder_hid_dim, cross_attention_dim) hidden_dims = [init_channels] + list(block_out_channels) in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:])) text_dims = [cross_attention_dim if is_exist else None for is_exist in add_cross_attention] num_blocks = len(block_out_channels) * [layers_per_block] layer_params = [num_blocks, text_dims, add_self_attention] rev_layer_params = map(reversed, layer_params) cat_dims = [] self.num_levels = len(in_out_dims) self.down_blocks = nn.ModuleList([]) for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate( zip(in_out_dims, *layer_params) ): down_sample = level != (self.num_levels - 1) cat_dims.append(out_dim if level != (self.num_levels - 1) else 0) self.down_blocks.append( Kandinsky3DownSampleBlock( in_dim, out_dim, time_embedding_dim, text_dim, res_block_num, groups, attention_head_dim, expansion_ratio, compression_ratio, down_sample, self_attention, ) ) self.up_blocks = nn.ModuleList([]) for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate( zip(reversed(in_out_dims), *rev_layer_params) ): up_sample = level != 0 self.up_blocks.append( Kandinsky3UpSampleBlock( in_dim, cat_dims.pop(), out_dim, time_embedding_dim, text_dim, res_block_num, groups, attention_head_dim, expansion_ratio, compression_ratio, up_sample, self_attention, ) ) self.conv_norm_out = nn.GroupNorm(groups, init_channels) self.conv_act_out = nn.SiLU() self.conv_out = nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1) @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. 
""" # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "set_processor"): processors[f"{name}.processor"] = module.processor for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. 
""" self.set_attn_processor(AttnProcessor()) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True): if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) if not torch.is_tensor(timestep): dtype = torch.float32 if isinstance(timestep, float) else torch.int32 timestep = torch.tensor([timestep], dtype=dtype, device=sample.device) elif len(timestep.shape) == 0: timestep = timestep[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = timestep.expand(sample.shape[0]) time_embed_input = self.time_proj(timestep).to(sample.dtype) time_embed = self.time_embedding(time_embed_input) encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) if encoder_hidden_states is not None: time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask) hidden_states = [] sample = self.conv_in(sample) for level, down_sample in enumerate(self.down_blocks): sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask) if level != self.num_levels - 1: hidden_states.append(sample) for level, up_sample in enumerate(self.up_blocks): if level != 0: sample = torch.cat([sample, hidden_states.pop()], dim=1) sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask) sample = self.conv_norm_out(sample) sample = self.conv_act_out(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return Kandinsky3UNetOutput(sample=sample) class Kandinsky3UpSampleBlock(nn.Module): def __init__( self, in_channels, cat_dim, out_channels, time_embed_dim, context_dim=None, num_blocks=3, groups=32, head_dim=64, 
expansion_ratio=4, compression_ratio=2, up_sample=True, self_attention=True, ): super().__init__() up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1) hidden_channels = ( [(in_channels + cat_dim, in_channels)] + [(in_channels, in_channels)] * (num_blocks - 2) + [(in_channels, out_channels)] ) attentions = [] resnets_in = [] resnets_out = [] self.self_attention = self_attention self.context_dim = context_dim if self_attention: attentions.append( Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio) ) else: attentions.append(nn.Identity()) for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): resnets_in.append( Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution) ) if context_dim is not None: attentions.append( Kandinsky3AttentionBlock( in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio ) ) else: attentions.append(nn.Identity()) resnets_out.append( Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) ) self.attentions = nn.ModuleList(attentions) self.resnets_in = nn.ModuleList(resnets_in) self.resnets_out = nn.ModuleList(resnets_out) def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None): for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out): x = resnet_in(x, time_embed) if self.context_dim is not None: x = attention(x, time_embed, context, context_mask, image_mask) x = resnet_out(x, time_embed) if self.self_attention: x = self.attentions[0](x, time_embed, image_mask=image_mask) return x class Kandinsky3DownSampleBlock(nn.Module): def __init__( self, in_channels, out_channels, time_embed_dim, context_dim=None, num_blocks=3, groups=32, head_dim=64, expansion_ratio=4, compression_ratio=2, down_sample=True, self_attention=True, ): super().__init__() attentions = [] 
class Kandinsky3ConditionalGroupNorm(nn.Module):
    """
    Group normalization whose per-channel scale and shift are predicted from a context vector.

    The normalization itself carries no learnable affine parameters (`affine=False`). Instead, `context` is passed
    through `SiLU -> Linear` to produce `2 * normalized_shape` values that are split into a scale half and a shift
    half. The linear layer is zero-initialized, so a freshly constructed module returns the plain group-normalized
    input (scale offset 0, shift 0).

    Parameters:
        groups: Number of groups handed to `nn.GroupNorm`.
        normalized_shape: Number of channels being normalized.
        context_dim: Size of the last dimension of the `context` input.
    """

    def __init__(self, groups, normalized_shape, context_dim):
        super().__init__()
        self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)

        # Zero the projection before wrapping it so the modulation starts out as the identity.
        projection = nn.Linear(context_dim, 2 * normalized_shape)
        projection.weight.data.zero_()
        projection.bias.data.zero_()
        # Keep the (SiLU, Linear) layout so the parameters stay under `context_mlp.1.*`.
        self.context_mlp = nn.Sequential(nn.SiLU(), projection)

    def forward(self, x, context):
        """
        Normalize `x` and modulate it with the scale/shift derived from `context`.

        Args:
            x: Tensor to normalize; channels are expected on dim 1.
            context: Conditioning tensor fed to the context MLP.

        Returns:
            `norm(x) * (scale + 1.0) + shift`, where `scale` and `shift` are the two halves (along dim 1) of the MLP
            output.
        """
        modulation = self.context_mlp(context)

        # Append one singleton axis for every dimension of `x` beyond the first two so the
        # modulation broadcasts over them.
        extra_axes = len(x.shape[2:])
        modulation = modulation.reshape(*modulation.shape, *([1] * extra_axes))

        scale, shift = modulation.chunk(2, dim=1)
        return self.norm(x) * (scale + 1.0) + shift
class Kandinsky3AttentionPooling(nn.Module):
    """
    Pools a context sequence with attention and adds the pooled vector to `x`.

    The query is the mean of `context` over dim 1 (kept as a length-1 sequence); the keys/values are the full
    `context`. The attention output is squeezed back to a vector and added to `x`.

    Parameters:
        num_channels: Passed to `Attention` as `out_dim`.
        context_dim: Passed to `Attention` as its first two positional arguments.
        head_dim: Passed to `Attention` as `dim_head`.
    """

    def __init__(self, num_channels, context_dim, head_dim=64):
        super().__init__()
        self.attention = Attention(
            context_dim,
            context_dim,
            dim_head=head_dim,
            out_dim=num_channels,
            out_bias=False,
        )

    def forward(self, x, context, context_mask=None):
        """
        Args:
            x: Tensor the pooled context is added to.
            context: Sequence to pool; averaged over dim 1 to build the query.
            context_mask: Optional mask forwarded to the attention layer after being cast to `context.dtype`. May be
                omitted, in which case `None` is forwarded.

        Returns:
            `x` plus the attention-pooled context with its dim-1 singleton removed.
        """
        # `context_mask` defaults to None, so only cast it when a mask was actually supplied;
        # calling `.to(...)` unconditionally raised AttributeError for the default argument.
        if context_mask is not None:
            context_mask = context_mask.to(dtype=context.dtype)
        context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
        return x + context.squeeze(1)
2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width) x = x + out out = self.out_norm(x, time_embed) out = self.feed_forward(out) x = x + out return x ================================================ FILE: src/diffusers/models/unets/unet_motion_model.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin from ...utils import BaseOutput, deprecate, is_torch_version, logging from ...utils.torch_utils import apply_freeu from ..attention import BasicTransformerBlock from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, AttnProcessor2_0, FusedAttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, ) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D from ..transformers.dual_transformer_2d import DualTransformer2DModel from ..transformers.transformer_2d import Transformer2DModel from .unet_2d_blocks import 
class AnimateDiffTransformer3D(nn.Module):
    """
    A Transformer model for video-like data that attends along the frame axis.

    The input is a stack of frames `(batch_size * num_frames, channels, height, width)`. Each spatial location is
    turned into a sequence of `num_frames` tokens, run through the transformer blocks, and the result is added back
    to the input as a residual.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output. It sizes the group norm and both linear projections.
        out_channels (`int`, *optional*):
            Accepted for interface compatibility; not referenced by this constructor.
        num_layers (`int`, *optional*, defaults to 1): The number of Transformer blocks to stack.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability passed to each block.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups of the input `GroupNorm`.
        cross_attention_dim (`int`, *optional*): The `cross_attention_dim` passed to each block.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the `TransformerBlock` attention should contain a bias parameter.
        sample_size (`int`, *optional*):
            Accepted for interface compatibility; not referenced by this constructor.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for
            supported activation functions.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Configure if the `TransformerBlock` should use learnable elementwise affine parameters for
            normalization.
        double_self_attention (`bool`, *optional*, defaults to `True`):
            Configure if each `TransformerBlock` should contain two self-attention layers.
        positional_embeddings (`str`, *optional*):
            The type of positional embeddings to apply to the sequence input; forwarded to each block.
        num_positional_embeddings (`int`, *optional*):
            The maximum length of the sequence over which to apply positional embeddings; forwarded to each block.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        double_self_attention: bool = True,
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels

        # Width of the token embedding used inside the transformer blocks.
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Input normalization and projection into the transformer width
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # 2. Transformer blocks (all layers share one configuration)
        block_config = {
            "dropout": dropout,
            "cross_attention_dim": cross_attention_dim,
            "activation_fn": activation_fn,
            "attention_bias": attention_bias,
            "double_self_attention": double_self_attention,
            "norm_elementwise_affine": norm_elementwise_affine,
            "positional_embeddings": positional_embeddings,
            "num_positional_embeddings": num_positional_embeddings,
        }
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, **block_config)
                for _ in range(num_layers)
            ]
        )

        # 3. Projection back to the input channel count
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.LongTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        num_frames: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        The [`AnimateDiffTransformer3D`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size * num_frames, channel, height, width)`):
                Input hidden_states.
            encoder_hidden_states (*optional*):
                Forwarded unchanged to every transformer block as `encoder_hidden_states`.
            timestep (*optional*):
                Forwarded unchanged to every transformer block as `timestep`.
            class_labels (*optional*):
                Forwarded unchanged to every transformer block as `class_labels`.
            num_frames (`int`, *optional*, defaults to 1):
                The number of frames per sample. The leading dimension of `hidden_states` is split into
                `(batch_size, num_frames)` with `batch_size = leading_dim // num_frames`.
            cross_attention_kwargs (`dict`, *optional*):
                Forwarded unchanged to every transformer block as `cross_attention_kwargs`.

        Returns:
            torch.Tensor:
                Tensor with the same shape as `hidden_states`: the transformed frames plus the input residual.
        """
        # 1. Input
        total_frames, channel, height, width = hidden_states.shape
        batch_size = total_frames // num_frames
        skip_connection = hidden_states

        # (B * F, C, H, W) -> (B, C, F, H, W) so the group norm sees channels on dim 1.
        video = hidden_states.reshape(batch_size, num_frames, channel, height, width)
        video = self.norm(video.permute(0, 2, 1, 3, 4))

        # (B, C, F, H, W) -> (B * H * W, F, C): one sequence of frames per spatial location.
        tokens = video.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
        tokens = self.proj_in(tokens)

        # 2. Blocks
        for block in self.transformer_blocks:
            tokens = block(
                tokens,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        tokens = self.proj_out(tokens)
        # (B * H * W, F, C) -> (B, F, C, H, W) -> (B * F, C, H, W)
        frames = tokens.reshape(batch_size, height, width, num_frames, channel).permute(0, 3, 4, 1, 2).contiguous()
        frames = frames.reshape(total_frames, channel, height, width)

        return frames + skip_connection
eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) motion_modules.append( AnimateDiffTransformer3D( num_attention_heads=temporal_num_attention_heads[i], in_channels=out_channels, num_layers=temporal_transformer_layers_per_block[i], norm_num_groups=resnet_groups, cross_attention_dim=temporal_cross_attention_dim, attention_bias=False, activation_fn="geglu", positional_embeddings="sinusoidal", num_positional_embeddings=temporal_max_seq_length, attention_head_dim=out_channels // temporal_num_attention_heads[i], double_self_attention=temporal_double_self_attention, ) ) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op", ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, num_frames: int = 1, *args, **kwargs, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
class CrossAttnDownBlockMotion(nn.Module):
    """
    Down-sampling UNet block that interleaves, per layer, a resnet, a spatial cross-attention transformer and a
    temporal (motion) transformer, optionally followed by a strided-convolution downsampler.

    Each layer `i` is built as:
        * `ResnetBlock2D` mapping `in_channels` (first layer) or `out_channels` (later layers) to `out_channels`;
        * `Transformer2DModel` with `transformer_layers_per_block[i]` layers, or a single-layer
          `DualTransformer2DModel` when `dual_cross_attention` is set;
        * `AnimateDiffTransformer3D` with `temporal_transformer_layers_per_block[i]` layers, sinusoidal positional
          embeddings and a head size of `out_channels // temporal_num_attention_heads`.

    `transformer_layers_per_block` and `temporal_transformer_layers_per_block` may each be a single integer (used for
    every layer) or a sequence of length `num_layers`.

    Raises:
        ValueError: If either per-layer sequence does not have exactly `num_layers` entries.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        add_downsample: bool = True,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
        temporal_cross_attention_dim: Optional[int] = None,
        temporal_num_attention_heads: int = 8,
        temporal_max_seq_length: int = 32,
        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        temporal_double_self_attention: bool = True,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        def per_layer(value, label):
            # Broadcast a single integer to every layer, otherwise require one entry per layer.
            if isinstance(value, int):
                return (value,) * num_layers
            if len(value) != num_layers:
                raise ValueError(f"{label} must be an integer or a list of integers of length {num_layers}")
            return value

        # support for variable transformer layers per block
        transformer_layers_per_block = per_layer(transformer_layers_per_block, "transformer_layers_per_block")
        # support for variable transformer layers per temporal block
        temporal_transformer_layers_per_block = per_layer(
            temporal_transformer_layers_per_block, "temporal_transformer_layers_per_block"
        )

        resnet_layers = []
        spatial_layers = []
        temporal_layers = []
        spatial_head_dim = out_channels // num_attention_heads

        # Modules are created in resnet -> attention -> motion order for every layer.
        for layer_idx in range(num_layers):
            layer_in_channels = in_channels if layer_idx == 0 else out_channels

            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

            if dual_cross_attention:
                spatial_attention = DualTransformer2DModel(
                    num_attention_heads,
                    spatial_head_dim,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            else:
                spatial_attention = Transformer2DModel(
                    num_attention_heads,
                    spatial_head_dim,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[layer_idx],
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    attention_type=attention_type,
                )
            spatial_layers.append(spatial_attention)

            temporal_layers.append(
                AnimateDiffTransformer3D(
                    num_attention_heads=temporal_num_attention_heads,
                    in_channels=out_channels,
                    num_layers=temporal_transformer_layers_per_block[layer_idx],
                    norm_num_groups=resnet_groups,
                    cross_attention_dim=temporal_cross_attention_dim,
                    attention_bias=False,
                    activation_fn="geglu",
                    positional_embeddings="sinusoidal",
                    num_positional_embeddings=temporal_max_seq_length,
                    attention_head_dim=out_channels // temporal_num_attention_heads,
                    double_self_attention=temporal_double_self_attention,
                )
            )

        # Registration order (attentions, resnets, motion_modules) is kept stable on purpose.
        self.attentions = nn.ModuleList(spatial_layers)
        self.resnets = nn.ModuleList(resnet_layers)
        self.motion_modules = nn.ModuleList(temporal_layers)

        self.downsamplers = None
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        num_frames: int = 1,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        additional_residuals: Optional[torch.Tensor] = None,
    ):
        """
        Run every (resnet, attention, motion module) triple and the optional downsampler.

        Args:
            hidden_states: Input feature map.
            temb: Time embedding passed to each resnet.
            encoder_hidden_states: Passed to each spatial attention module.
            attention_mask: Passed to each spatial attention module.
            num_frames: Passed to each motion module.
            encoder_attention_mask: Passed to each spatial attention module.
            cross_attention_kwargs: Passed to each spatial attention module. A non-`None` `"scale"` entry only
                triggers a warning.
            additional_residuals: If given, added to the output of the last layer before it is recorded.

        Returns:
            A tuple `(hidden_states, output_states)` where `output_states` holds the output of every layer followed
            by the output of every downsampler.
        """
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        def wrap_for_checkpoint(module):
            def run(*inputs):
                return module(*inputs)

            return run

        layers = list(zip(self.resnets, self.attentions, self.motion_modules))
        last_layer = len(layers) - 1
        output_states = ()

        for layer_idx, (resnet, attn, motion_module) in enumerate(layers):
            # Only the resnet is wrapped in activation checkpointing; attention and the
            # motion module always run directly.
            if self.training and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    wrap_for_checkpoint(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            attn_outputs = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )
            hidden_states = attn_outputs[0]

            hidden_states = motion_module(hidden_states, num_frames=num_frames)

            # apply additional residuals to the output of the last layer
            if additional_residuals is not None and layer_idx == last_layer:
                hidden_states = hidden_states + additional_residuals

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads: int = 1, cross_attention_dim: int = 1280, output_scale_factor: float = 1.0, add_upsample: bool = True, dual_cross_attention: bool = False, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", temporal_cross_attention_dim: Optional[int] = None, temporal_num_attention_heads: int = 8, temporal_max_seq_length: int = 32, temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, ): super().__init__() resnets = [] attentions = [] motion_modules = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads # support for variable transformer layers per block if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = (transformer_layers_per_block,) * num_layers elif len(transformer_layers_per_block) != num_layers: raise ValueError( f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}" ) # support for variable transformer layers per temporal block if isinstance(temporal_transformer_layers_per_block, int): temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers elif len(temporal_transformer_layers_per_block) != num_layers: raise ValueError( f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}" ) for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, 
time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if not dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, attention_type=attention_type, ) ) else: attentions.append( DualTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) motion_modules.append( AnimateDiffTransformer3D( num_attention_heads=temporal_num_attention_heads, in_channels=out_channels, num_layers=temporal_transformer_layers_per_block[i], norm_num_groups=resnet_groups, cross_attention_dim=temporal_cross_attention_dim, attention_bias=False, activation_fn="geglu", positional_embeddings="sinusoidal", num_positional_embeddings=temporal_max_seq_length, attention_head_dim=out_channels // temporal_num_attention_heads, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False self.resolution_idx = resolution_idx def forward( self, hidden_states: torch.Tensor, res_hidden_states_tuple: Tuple[torch.Tensor, ...], temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, 
num_frames: int = 1, ) -> torch.Tensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) and getattr(self, "b1", None) and getattr(self, "b2", None) ) blocks = zip(self.resnets, self.attentions, self.motion_modules) for resnet, attn, motion_module in blocks: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] # FreeU: Only operate on the first two stages if is_freeu_enabled: hidden_states, res_hidden_states = apply_freeu( self.resolution_idx, hidden_states, res_hidden_states, s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2, ) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, ) if self.upsamplers is not None: for 
class UpBlockMotion(nn.Module):
    """
    Up-sampling UNet block made of (resnet, temporal motion module) pairs, optionally followed by an upsampler.

    Each layer `i` is built as:
        * `ResnetBlock2D` whose input is the concatenation of the running hidden states (`prev_output_channel`
          channels for the first layer, `out_channels` afterwards) and a skip connection (`in_channels` channels for
          the last layer, `out_channels` otherwise);
        * `AnimateDiffTransformer3D` with `temporal_transformer_layers_per_block[i]` layers, sinusoidal positional
          embeddings and a head size of `out_channels // temporal_num_attention_heads`.

    `temporal_transformer_layers_per_block` may be a single integer (used for every layer) or a sequence of length
    `num_layers`.

    Raises:
        ValueError: If `temporal_transformer_layers_per_block` is a sequence whose length is not `num_layers`.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        temporal_cross_attention_dim: Optional[int] = None,
        temporal_num_attention_heads: int = 8,
        temporal_max_seq_length: int = 32,
        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    ):
        super().__init__()

        # support for variable transformer layers per temporal block
        if isinstance(temporal_transformer_layers_per_block, int):
            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
        elif len(temporal_transformer_layers_per_block) != num_layers:
            raise ValueError(
                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
            )

        resnet_layers = []
        temporal_layers = []
        final_layer = num_layers - 1

        # Modules are created in resnet -> motion order for every layer.
        for layer_idx in range(num_layers):
            skip_channels = in_channels if layer_idx == final_layer else out_channels
            running_channels = prev_output_channel if layer_idx == 0 else out_channels

            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=running_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temporal_layers.append(
                AnimateDiffTransformer3D(
                    num_attention_heads=temporal_num_attention_heads,
                    in_channels=out_channels,
                    num_layers=temporal_transformer_layers_per_block[layer_idx],
                    norm_num_groups=resnet_groups,
                    cross_attention_dim=temporal_cross_attention_dim,
                    attention_bias=False,
                    activation_fn="geglu",
                    positional_embeddings="sinusoidal",
                    num_positional_embeddings=temporal_max_seq_length,
                    attention_head_dim=out_channels // temporal_num_attention_heads,
                )
            )

        self.resnets = nn.ModuleList(resnet_layers)
        self.motion_modules = nn.ModuleList(temporal_layers)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        upsample_size=None,
        num_frames: int = 1,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run every (resnet, motion module) pair on the hidden states concatenated with a skip connection.

        Args:
            hidden_states: Running feature map.
            res_hidden_states_tuple: Skip connections; one is consumed from the end for each layer.
            temb: Time embedding passed to each resnet.
            upsample_size: Passed as the second argument to each upsampler.
            num_frames: Passed to each motion module.
            *args, **kwargs: Only inspected for the deprecated `scale` argument, which is ignored.

        Returns:
            The hidden states after all layers and, if present, the upsamplers.
        """
        if args or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        # FreeU is active only when all four factors are present on the module and truthy.
        is_freeu_enabled = all(getattr(self, factor, None) for factor in ("s1", "s2", "b1", "b2"))

        def wrap_for_checkpoint(module):
            def run(*inputs):
                return module(*inputs)

            return run

        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            # pop res hidden states
            skip_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            # FreeU: Only operate on the first two stages
            if is_freeu_enabled:
                hidden_states, skip_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    skip_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )

            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            # Only the resnet is wrapped in activation checkpointing.
            if self.training and self.gradient_checkpointing:
                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    wrap_for_checkpoint(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            hidden_states = motion_module(hidden_states, num_frames=num_frames)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
temporal_num_attention_heads: int = 1, temporal_cross_attention_dim: Optional[int] = None, temporal_max_seq_length: int = 32, temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, ): super().__init__() self.has_cross_attention = True self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) # support for variable transformer layers per block if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = (transformer_layers_per_block,) * num_layers elif len(transformer_layers_per_block) != num_layers: raise ValueError( f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." ) # support for variable transformer layers per temporal block if isinstance(temporal_transformer_layers_per_block, int): temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers elif len(temporal_transformer_layers_per_block) != num_layers: raise ValueError( f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." 
) # there is always at least one resnet resnets = [ ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ] attentions = [] motion_modules = [] for i in range(num_layers): if not dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, ) ) else: attentions.append( DualTransformer2DModel( num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) motion_modules.append( AnimateDiffTransformer3D( num_attention_heads=temporal_num_attention_heads, attention_head_dim=in_channels // temporal_num_attention_heads, in_channels=in_channels, num_layers=temporal_transformer_layers_per_block[i], norm_num_groups=resnet_groups, cross_attention_dim=temporal_cross_attention_dim, attention_bias=False, positional_embeddings="sinusoidal", num_positional_embeddings=temporal_max_seq_length, activation_fn="geglu", ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, 
temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, num_frames: int = 1, ) -> torch.Tensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") hidden_states = self.resnets[0](hidden_states, temb) blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) for attn, resnet, motion_module in blocks: if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(motion_module), hidden_states, temb, **ckpt_kwargs, ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) else: hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] hidden_states = motion_module( hidden_states, num_frames=num_frames, ) hidden_states = resnet(hidden_states, temb) return hidden_states class MotionModules(nn.Module): def __init__( self, in_channels: int, layers_per_block: int = 2, transformer_layers_per_block: 
class MotionAdapter(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    @register_to_config
    def __init__(
        self,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        motion_layers_per_block: Union[int, Tuple[int]] = 2,
        motion_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1,
        motion_mid_block_layers_per_block: int = 1,
        motion_transformer_layers_per_mid_block: Union[int, Tuple[int]] = 1,
        motion_num_attention_heads: Union[int, Tuple[int]] = 8,
        motion_norm_num_groups: int = 32,
        motion_max_seq_length: int = 32,
        use_motion_mid_block: bool = True,
        conv_in_channels: Optional[int] = None,
    ):
        """Container to store AnimateDiff Motion Modules

        Args:
            block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
                The tuple of output channels for each UNet block.
            motion_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 2):
                The number of motion layers per UNet block. Up blocks get one more layer than the matching entry.
            motion_transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple[int]]`, *optional*, defaults to 1):
                The number of transformer layers to use in each motion layer in each block.
            motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1):
                The number of motion layers in the middle UNet block.
            motion_transformer_layers_per_mid_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
                The number of transformer layers to use in each motion layer in the middle block.
            motion_num_attention_heads (`int` or `Tuple[int]`, *optional*, defaults to 8):
                The number of heads to use in each attention layer of the motion module.
            motion_norm_num_groups (`int`, *optional*, defaults to 32):
                The number of groups to use in each group normalization layer of the motion module.
            motion_max_seq_length (`int`, *optional*, defaults to 32):
                The maximum sequence length to use in the motion module.
            use_motion_mid_block (`bool`, *optional*, defaults to True):
                Whether to use a motion module in the middle of the UNet.
            conv_in_channels (`int`, *optional*, defaults to `None`):
                When set to a non-zero value, a 3x3 `nn.Conv2d` from `conv_in_channels` to `block_out_channels[0]`
                is created as `conv_in`. Otherwise `conv_in` is `None`.

        Raises:
            ValueError: If a per-block sequence does not have one entry per block, or if
                `motion_transformer_layers_per_mid_block` does not have `motion_mid_block_layers_per_block` entries.
        """

        super().__init__()
        num_blocks = len(block_out_channels)

        if isinstance(motion_layers_per_block, int):
            motion_layers_per_block = (motion_layers_per_block,) * num_blocks
        elif len(motion_layers_per_block) != num_blocks:
            raise ValueError(
                f"The number of motion layers per block must match the number of blocks, "
                f"got {len(block_out_channels)} and {len(motion_layers_per_block)}"
            )

        if isinstance(motion_transformer_layers_per_block, int):
            motion_transformer_layers_per_block = (motion_transformer_layers_per_block,) * num_blocks
        elif len(motion_transformer_layers_per_block) != num_blocks:
            # Without this check a short sequence fails later with a bare IndexError, and a long one is
            # indexed from opposite ends by the down and up blocks, silently mis-pairing the values.
            raise ValueError(
                f"The number of motion transformer layers per block must match the number of blocks, "
                f"got {len(block_out_channels)} and {len(motion_transformer_layers_per_block)}"
            )

        if isinstance(motion_transformer_layers_per_mid_block, int):
            motion_transformer_layers_per_mid_block = (
                motion_transformer_layers_per_mid_block,
            ) * motion_mid_block_layers_per_block
        elif len(motion_transformer_layers_per_mid_block) != motion_mid_block_layers_per_block:
            raise ValueError(
                f"The number of layers per mid block ({motion_mid_block_layers_per_block}) "
                f"must match the length of motion_transformer_layers_per_mid_block ({len(motion_transformer_layers_per_mid_block)})"
            )

        if isinstance(motion_num_attention_heads, int):
            motion_num_attention_heads = (motion_num_attention_heads,) * num_blocks
        elif len(motion_num_attention_heads) != num_blocks:
            raise ValueError(
                f"The length of the attention head number tuple in the motion module must match the "
                f"number of block, got {len(motion_num_attention_heads)} and {len(block_out_channels)}"
            )

        if conv_in_channels:
            # input
            self.conv_in = nn.Conv2d(conv_in_channels, block_out_channels[0], kernel_size=3, padding=1)
        else:
            self.conv_in = None

        def build_motion_modules(channels, heads, layers, transformer_layers):
            # All blocks share every setting except the four passed in here.
            return MotionModules(
                in_channels=channels,
                norm_num_groups=motion_norm_num_groups,
                cross_attention_dim=None,
                activation_fn="geglu",
                attention_bias=False,
                num_attention_heads=heads,
                max_seq_length=motion_max_seq_length,
                layers_per_block=layers,
                transformer_layers_per_block=transformer_layers,
            )

        # Construction order (down, mid, up) is kept so parameter initialisation order is unchanged.
        down_blocks = [
            build_motion_modules(
                channel,
                motion_num_attention_heads[i],
                motion_layers_per_block[i],
                motion_transformer_layers_per_block[i],
            )
            for i, channel in enumerate(block_out_channels)
        ]

        if use_motion_mid_block:
            self.mid_block = build_motion_modules(
                block_out_channels[-1],
                motion_num_attention_heads[-1],
                motion_mid_block_layers_per_block,
                motion_transformer_layers_per_mid_block,
            )
        else:
            self.mid_block = None

        # Up blocks walk the per-block settings in reverse and hold one extra motion layer each.
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_motion_layers_per_block = list(reversed(motion_layers_per_block))
        reversed_motion_transformer_layers_per_block = list(reversed(motion_transformer_layers_per_block))
        reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads))
        up_blocks = [
            build_motion_modules(
                channel,
                reversed_motion_num_attention_heads[i],
                reversed_motion_layers_per_block[i] + 1,
                reversed_motion_transformer_layers_per_block[i],
            )
            for i, channel in enumerate(reversed_block_out_channels)
        ]

        self.down_blocks = nn.ModuleList(down_blocks)
        self.up_blocks = nn.ModuleList(up_blocks)

    def forward(self, sample):
        # The adapter only stores weights; it performs no computation.
        pass
""" _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, down_block_types: Tuple[str, ...] = ( "CrossAttnDownBlockMotion", "CrossAttnDownBlockMotion", "CrossAttnDownBlockMotion", "DownBlockMotion", ), up_block_types: Tuple[str, ...] = ( "UpBlockMotion", "CrossAttnUpBlockMotion", "CrossAttnUpBlockMotion", "CrossAttnUpBlockMotion", ), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: int = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, reverse_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None, temporal_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, reverse_temporal_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None, transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None, temporal_transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = 1, use_linear_projection: bool = False, num_attention_heads: Union[int, Tuple[int, ...]] = 8, motion_max_seq_length: int = 32, motion_num_attention_heads: Union[int, Tuple[int, ...]] = 8, reverse_motion_num_attention_heads: Optional[Union[int, Tuple[int, ...], Tuple[Tuple[int, ...], ...]]] = None, use_motion_mid_block: bool = True, mid_block_layers: int = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, projection_class_embeddings_input_dim: Optional[int] = None, time_cond_proj_dim: Optional[int] = None, ): super().__init__() self.sample_size = sample_size # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( f"Must 
provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: for layer_number_per_block in transformer_layers_per_block: if isinstance(layer_number_per_block, list): raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") if ( isinstance(temporal_transformer_layers_per_block, list) and reverse_temporal_transformer_layers_per_block is None ): for layer_number_per_block in temporal_transformer_layers_per_block: if isinstance(layer_number_per_block, list): raise ValueError( "Must provide 'reverse_temporal_transformer_layers_per_block` if using asymmetrical motion module in UNet." 
) # input conv_in_kernel = 3 conv_out_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time time_embed_dim = block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], True, 0) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, cond_proj_dim=time_cond_proj_dim ) if encoder_hid_dim_type is None: self.encoder_hid_proj = None if addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) # class embedding self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if isinstance(reverse_transformer_layers_per_block, int): reverse_transformer_layers_per_block = [reverse_transformer_layers_per_block] * len(down_block_types) if isinstance(temporal_transformer_layers_per_block, int): temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types) if isinstance(reverse_temporal_transformer_layers_per_block, int): reverse_temporal_transformer_layers_per_block = [reverse_temporal_transformer_layers_per_block] * len( down_block_types ) if isinstance(motion_num_attention_heads, int): motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types) # down output_channel = block_out_channels[0] for i, 
down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 if down_block_type == "CrossAttnDownBlockMotion": down_block = CrossAttnDownBlockMotion( in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, num_attention_heads=num_attention_heads[i], cross_attention_dim=cross_attention_dim[i], downsample_padding=downsample_padding, add_downsample=not is_final_block, use_linear_projection=use_linear_projection, temporal_num_attention_heads=motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], ) elif down_block_type == "DownBlockMotion": down_block = DownBlockMotion( in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, num_layers=layers_per_block[i], resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, add_downsample=not is_final_block, downsample_padding=downsample_padding, temporal_num_attention_heads=motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], ) else: raise ValueError( "Invalid `down_block_type` encountered. 
Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`" ) self.down_blocks.append(down_block) # mid if transformer_layers_per_mid_block is None: transformer_layers_per_mid_block = ( transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1 ) if use_motion_mid_block: self.mid_block = UNetMidBlockCrossAttnMotion( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, use_linear_projection=use_linear_projection, num_layers=mid_block_layers, temporal_num_attention_heads=motion_num_attention_heads[-1], temporal_max_seq_length=motion_max_seq_length, transformer_layers_per_block=transformer_layers_per_mid_block, temporal_transformer_layers_per_block=temporal_transformer_layers_per_mid_block, ) else: self.mid_block = UNetMidBlock2DCrossAttn( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, use_linear_projection=use_linear_projection, num_layers=mid_block_layers, transformer_layers_per_block=transformer_layers_per_mid_block, ) # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads)) if reverse_transformer_layers_per_block is None: reverse_transformer_layers_per_block = 
list(reversed(transformer_layers_per_block)) if reverse_temporal_transformer_layers_per_block is None: reverse_temporal_transformer_layers_per_block = list(reversed(temporal_transformer_layers_per_block)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False if up_block_type == "CrossAttnUpBlockMotion": up_block = CrossAttnUpBlockMotion( in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=time_embed_dim, resolution_idx=i, num_layers=reversed_layers_per_block[i] + 1, transformer_layers_per_block=reverse_transformer_layers_per_block[i], resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, num_attention_heads=reversed_num_attention_heads[i], cross_attention_dim=reversed_cross_attention_dim[i], add_upsample=add_upsample, use_linear_projection=use_linear_projection, temporal_num_attention_heads=reversed_motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i], ) elif up_block_type == "UpBlockMotion": up_block = UpBlockMotion( in_channels=input_channel, prev_output_channel=prev_output_channel, out_channels=output_channel, temb_channels=time_embed_dim, resolution_idx=i, num_layers=reversed_layers_per_block[i] + 1, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, add_upsample=add_upsample, temporal_num_attention_heads=reversed_motion_num_attention_heads[i], temporal_max_seq_length=motion_max_seq_length, 
@classmethod
def from_unet2d(
    cls,
    unet: UNet2DConditionModel,
    motion_adapter: Optional[MotionAdapter] = None,
    load_weights: bool = True,
):
    """Build a motion UNet from a `UNet2DConditionModel`, optionally merging a `MotionAdapter`.

    The UNet config is copied, its block types are mapped to their `*Motion` counterparts, and, when an adapter
    is given, the adapter's motion settings are written into the config before the model is created with
    `cls.from_config`.

    Args:
        unet: Source model. Its config is copied and, if `load_weights` is true, its weights are loaded into the
            new model.
        motion_adapter: Optional adapter. It is moved to `unet.device`, checked against the UNet's number of
            blocks and layers per block, and its motion modules are loaded via `load_motion_modules`.
        load_weights: If `False`, the model is returned right after construction, with no weights copied and no
            dtype cast.

    Returns:
        The new model. When weights are loaded it is cast to `unet.dtype`.

    Raises:
        ValueError: If the adapter's number of blocks or layers per block differs from the UNet's.
    """
    has_motion_adapter = motion_adapter is not None

    if has_motion_adapter:
        motion_adapter.to(device=unet.device)

        # check compatibility of number of blocks
        if len(unet.config["down_block_types"]) != len(motion_adapter.config["block_out_channels"]):
            raise ValueError("Incompatible Motion Adapter, got different number of blocks")

        # check layers compatibility for each block
        # (both sides are expanded to one entry per block so an int and a sequence compare equal)
        if isinstance(unet.config["layers_per_block"], int):
            expanded_layers_per_block = [unet.config["layers_per_block"]] * len(unet.config["down_block_types"])
        else:
            expanded_layers_per_block = list(unet.config["layers_per_block"])
        if isinstance(motion_adapter.config["motion_layers_per_block"], int):
            expanded_adapter_layers_per_block = [motion_adapter.config["motion_layers_per_block"]] * len(
                motion_adapter.config["block_out_channels"]
            )
        else:
            expanded_adapter_layers_per_block = list(motion_adapter.config["motion_layers_per_block"])
        if expanded_layers_per_block != expanded_adapter_layers_per_block:
            raise ValueError("Incompatible Motion Adapter, got different number of layers per block")

    # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
    config = dict(unet.config)
    config["_class_name"] = cls.__name__

    # Map every down block type to its motion variant, keyed on whether its name contains "CrossAttn".
    down_blocks = []
    for down_blocks_type in config["down_block_types"]:
        if "CrossAttn" in down_blocks_type:
            down_blocks.append("CrossAttnDownBlockMotion")
        else:
            down_blocks.append("DownBlockMotion")
    config["down_block_types"] = down_blocks

    # Same mapping for the up blocks (the loop variable name is reused from the loop above).
    up_blocks = []
    for down_blocks_type in config["up_block_types"]:
        if "CrossAttn" in down_blocks_type:
            up_blocks.append("CrossAttnUpBlockMotion")
        else:
            up_blocks.append("UpBlockMotion")
    config["up_block_types"] = up_blocks

    if has_motion_adapter:
        # Take the motion-related settings from the adapter.
        config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
        config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
        config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]
        config["layers_per_block"] = motion_adapter.config["motion_layers_per_block"]
        config["temporal_transformer_layers_per_mid_block"] = motion_adapter.config[
            "motion_transformer_layers_per_mid_block"
        ]
        config["temporal_transformer_layers_per_block"] = motion_adapter.config[
            "motion_transformer_layers_per_block"
        ]
        # NOTE(review): repeats the first assignment of this branch; it looks redundant.
        config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]

        # For PIA UNets we need to set the number input channels to 9
        if motion_adapter.config["conv_in_channels"]:
            config["in_channels"] = motion_adapter.config["conv_in_channels"]

    # Need this for backwards compatibility with UNet2DConditionModel checkpoints
    if not config.get("num_attention_heads"):
        config["num_attention_heads"] = config["attention_head_dim"]

    # Keep only the config keys that the constructor signature knows about.
    expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
    config = FrozenDict({k: config.get(k) for k in config if k in expected_kwargs or k in optional_kwargs})
    # Set the class name again; the filtering above may have dropped it.
    config["_class_name"] = cls.__name__
    model = cls.from_config(config)

    if not load_weights:
        return model

    # Logic for loading PIA UNets which allow the first 4 channels to be any UNet2DConditionModel conv_in weight
    # while the last 5 channels must be PIA conv_in weights.
    if has_motion_adapter and motion_adapter.config["conv_in_channels"]:
        # NOTE(review): this assigns the adapter's own module object, so the `load_state_dict` below also
        # writes into `motion_adapter.conv_in`. Confirm the sharing is intended.
        model.conv_in = motion_adapter.conv_in
        # Concatenate along the input-channel axis: UNet weights first, then adapter channels from index 4 on.
        updated_conv_in_weight = torch.cat(
            [unet.conv_in.weight, motion_adapter.conv_in.weight[:, 4:, :, :]], dim=1
        )
        model.conv_in.load_state_dict({"weight": updated_conv_in_weight, "bias": unet.conv_in.bias})
    else:
        model.conv_in.load_state_dict(unet.conv_in.state_dict())

    model.time_proj.load_state_dict(unet.time_proj.state_dict())
    model.time_embedding.load_state_dict(unet.time_embedding.state_dict())

    # If the source UNet uses any IP-Adapter attention processor, rebuild an equivalent processor layout.
    if any(
        isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
        for proc in unet.attn_processors.values()
    ):
        attn_procs = {}
        for name, processor in unet.attn_processors.items():
            # The "2_0" variants are chosen whenever `F.scaled_dot_product_attention` exists.
            if name.endswith("attn1.processor"):
                attn_processor_class = (
                    AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
                )
                attn_procs[name] = attn_processor_class()
            else:
                attn_processor_class = (
                    IPAdapterAttnProcessor2_0
                    if hasattr(F, "scaled_dot_product_attention")
                    else IPAdapterAttnProcessor
                )
                # NOTE(review): only hyperparameters are passed here; no processor weights are copied
                # in this block.
                attn_procs[name] = attn_processor_class(
                    hidden_size=processor.hidden_size,
                    cross_attention_dim=processor.cross_attention_dim,
                    scale=processor.scale,
                    num_tokens=processor.num_tokens,
                )
        # Processors that exist only in the new model get a fresh instance of their current class.
        for name, processor in model.attn_processors.items():
            if name not in attn_procs:
                attn_procs[name] = processor.__class__()
        model.set_attn_processor(attn_procs)
        model.config.encoder_hid_dim_type = "ip_image_proj"
        # The projection module is taken over by reference, not copied.
        model.encoder_hid_proj = unet.encoder_hid_proj

    # Copy resnet, attention and down-sampler weights block by block; motion modules are handled later.
    for i, down_block in enumerate(unet.down_blocks):
        model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict())
        if hasattr(model.down_blocks[i], "attentions"):
            model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict())
        if model.down_blocks[i].downsamplers:
            model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict())

    for i, up_block in enumerate(unet.up_blocks):
        model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict())
        if hasattr(model.up_blocks[i], "attentions"):
            model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict())
        if model.up_blocks[i].upsamplers:
            model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict())

    model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict())
    model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict())

    if unet.conv_norm_out is not None:
        model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict())
    if unet.conv_act is not None:
        model.conv_act.load_state_dict(unet.conv_act.state_dict())
    model.conv_out.load_state_dict(unet.conv_out.state_dict())

    if has_motion_adapter:
        model.load_motion_modules(motion_adapter)

    # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel
    model.to(unet.dtype)

    return model
def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None:
    """Copy the motion-module weights of `motion_adapter` into this model.

    For every down and up block of the adapter, the state dict of its `motion_modules` is loaded into the block
    with the same index in this model. The mid block is copied only when both this model's mid block has
    `motion_modules` and the adapter has a mid block.

    Args:
        motion_adapter: Adapter to read the weights from. `None` is accepted and leaves the model untouched.
    """
    if motion_adapter is None:
        # The parameter is optional: with no adapter there is nothing to load.
        return

    for i, down_block in enumerate(motion_adapter.down_blocks):
        self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict())
    for i, up_block in enumerate(motion_adapter.up_blocks):
        self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict())

    # to support older motion modules that don't have a mid_block: skip the copy when either side lacks one,
    # instead of dereferencing a missing adapter mid block
    adapter_mid_block = getattr(motion_adapter, "mid_block", None)
    if hasattr(self.mid_block, "motion_modules") and adapter_mid_block is not None:
        self.mid_block.motion_modules.load_state_dict(adapter_mid_block.motion_modules.state_dict())
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: every attention processor used in the model, keyed by the dotted
        path of the owning module followed by `.processor`.
    """
    collected = {}

    # Pre-order walk over the module tree using an explicit stack; children are pushed in reverse so that
    # they are visited (and inserted into the result) in their registration order.
    pending = list(self.named_children())[::-1]
    while pending:
        path, node = pending.pop()

        if hasattr(node, "get_processor"):
            collected[f"{path}.processor"] = node.get_processor()

        descendants = [(f"{path}.{child_name}", child) for child_name, child in node.named_children()]
        pending.extend(reversed(descendants))

    return collected
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
    """
    Turns on [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) for
    every sub-module that exposes `set_chunk_feed_forward`.

    Parameters:
        chunk_size (`int`, *optional*):
            The chunk size of the feed-forward layers. A missing (or falsy) value is replaced by 1, i.e. the
            feed-forward layer runs individually over each tensor of dim=`dim`.
        dim (`int`, *optional*, defaults to `0`):
            The dimension over which the feed-forward computation should be chunked. Choose between dim=0
            (batch) or dim=1 (sequence length).

    Raises:
        ValueError: If `dim` is neither 0 nor 1.
    """
    if dim not in (0, 1):
        raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

    # By default chunk size is 1
    effective_chunk_size = chunk_size if chunk_size else 1

    # Depth-first, pre-order traversal of all descendants (the model itself is not visited).
    pending = list(self.children())[::-1]
    while pending:
        node = pending.pop()
        if hasattr(node, "set_chunk_feed_forward"):
            node.set_chunk_feed_forward(chunk_size=effective_chunk_size, dim=dim)
        pending.extend(reversed(list(node.children())))
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
    """Set `module.gradient_checkpointing` to `value` when `module` is one of the motion down/up block types.

    Modules of any other type are left untouched.
    """
    motion_block_types = (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)
    if not isinstance(module, motion_block_types):
        return
    module.gradient_checkpointing = value
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self) -> None:
    """Disables the FreeU mechanism.

    For every block in `self.up_blocks`, each FreeU scaling attribute (`s1`, `s2`, `b1`, `b2`) that exists
    on the block is reset to `None`. Attributes that were never set are not created.
    """
    # A tuple keeps the reset order deterministic (a set has no guaranteed iteration order).
    freeu_keys = ("s1", "s2", "b1", "b2")
    for upsample_block in self.up_blocks:
        for key in freeu_keys:
            # `hasattr` alone is sufficient: an attribute with a non-None value necessarily exists.
            if hasattr(upsample_block, key):
                setattr(upsample_block, key, None)
def forward(
    self,
    sample: torch.Tensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNetMotionOutput, Tuple[torch.Tensor]]:
    r"""
    The [`UNetMotionModel`] forward method.

    Args:
        sample (`torch.Tensor`):
            The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`
            (the frame count is read from dimension 2).
        timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.Tensor`):
            The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
        timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
            Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
            through the `self.time_embedding` layer to obtain the timestep embeddings.
        attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
            is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        added_cond_kwargs (`dict`, *optional*):
            Extra conditioning inputs. Must contain `text_embeds` and `time_ids` when the config has
            `addition_embed_type == "text_time"`, and `image_embeds` when `encoder_hid_dim_type ==
            "ip_image_proj"` (and an `encoder_hid_proj` layer exists). `None` is treated as an empty dict.
        down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
            A tuple of tensors that if specified are added to the residuals of down unet blocks.
        mid_block_additional_residual: (`torch.Tensor`, *optional*):
            A tensor that if specified is added to the residual of the middle unet block.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unets.unet_motion_model.UNetMotionOutput`] instead of a plain
            tuple.

    Returns:
        [`~models.unets.unet_motion_model.UNetMotionOutput`] or `tuple`:
            If `return_dict` is True, an [`~models.unets.unet_motion_model.UNetMotionOutput`] is returned,
            otherwise a `tuple` is returned where the first element is the sample tensor.

    Raises:
        ValueError: If a conditioning entry required by the config is missing from `added_cond_kwargs`.
    """
    # Treat a missing kwargs dict as empty so that a required-but-absent entry triggers the descriptive
    # ValueError below instead of a TypeError from `in None`.
    if added_cond_kwargs is None:
        added_cond_kwargs = {}

    # By default samples have to be AT least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # prepare attention_mask: turn the keep(1)/discard(0) mask into an additive bias
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    num_frames = sample.shape[2]
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=self.dtype)

    emb = self.time_embedding(t_emb, timestep_cond)
    aug_emb = None

    if self.config.addition_embed_type == "text_time":
        if "text_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
            )

        text_embeds = added_cond_kwargs.get("text_embeds")
        if "time_ids" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
            )
        time_ids = added_cond_kwargs.get("time_ids")
        time_embeds = self.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(add_embeds)

    emb = emb if aug_emb is None else emb + aug_emb
    # Every frame is processed as its own batch entry, so per-sample conditioning is repeated per frame.
    emb = emb.repeat_interleave(repeats=num_frames, dim=0)
    encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

    if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        image_embeds = added_cond_kwargs.get("image_embeds")
        image_embeds = self.encoder_hid_proj(image_embeds)
        image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
        encoder_hidden_states = (encoder_hidden_states, image_embeds)

    # 2. pre-process: fold the frame dimension into the batch dimension
    sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
    sample = self.conv_in(sample)

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)

        down_block_res_samples += res_samples

    if down_block_additional_residuals is not None:
        new_down_block_res_samples = ()

        for down_block_res_sample, down_block_additional_residual in zip(
            down_block_res_samples, down_block_additional_residuals
        ):
            down_block_res_sample = down_block_res_sample + down_block_additional_residual
            new_down_block_res_samples += (down_block_res_sample,)

        down_block_res_samples = new_down_block_res_samples

    # 4. mid
    if self.mid_block is not None:
        # To support older versions of motion modules that don't have a mid_block
        if hasattr(self.mid_block, "motion_modules"):
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        else:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )

    if mid_block_additional_residual is not None:
        sample = sample + mid_block_additional_residual

    # 5. up
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        # Consume the most recent skip connections, one per resnet of this up block.
        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        else:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                upsample_size=upsample_size,
                num_frames=num_frames,
            )

    # 6. post-process
    if self.conv_norm_out:
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)

    sample = self.conv_out(sample)

    # unfold the frames again: (batch * num_frames, channel, ...) -> (batch, channel, num_frames, ...)
    sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)

    if not return_dict:
        return (sample,)

    return UNetMotionOutput(sample=sample)
    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        addition_time_embed_dim: (`int`, defaults to 256):
            Dimension used to encode the additional time ids.
        projection_class_embeddings_input_dim (`int`, defaults to 768):
            The dimension of the projection of encoded `added_time_ids`.
        layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 2): The number of layers per block.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unets.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
            [`~models.unets.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
            [`~models.unets.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
        num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 20, 20)`):
            The number of attention heads.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 8,
    out_channels: int = 4,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlockSpatioTemporal",
        "CrossAttnDownBlockSpatioTemporal",
        "CrossAttnDownBlockSpatioTemporal",
        "DownBlockSpatioTemporal",
    ),
    up_block_types: Tuple[str] = (
        "UpBlockSpatioTemporal",
        "CrossAttnUpBlockSpatioTemporal",
        "CrossAttnUpBlockSpatioTemporal",
        "CrossAttnUpBlockSpatioTemporal",
    ),
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    addition_time_embed_dim: int = 256,
    projection_class_embeddings_input_dim: int = 768,
    layers_per_block: Union[int, Tuple[int]] = 2,
    cross_attention_dim: Union[int, Tuple[int]] = 1024,
    transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
    num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
    num_frames: int = 25,
):
    """Build the input convolution, time embeddings, down/mid/up blocks and the output head.

    Scalar values for `num_attention_heads`, `cross_attention_dim`, `layers_per_block` and
    `transformer_layers_per_block` are broadcast to one entry per down block.

    Raises:
        ValueError: If the per-block arguments do not have one entry per `down_block_types` entry, or
            if `down_block_types` and `up_block_types` differ in length.
    """
    super().__init__()

    self.sample_size = sample_size

    # Check inputs
    if len(down_block_types) != len(up_block_types):
        raise ValueError(
            f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
        )

    if len(block_out_channels) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
        )

    # Validate every sequence type (tuple as well as list), consistent with the sibling checks.
    if not isinstance(cross_attention_dim, int) and len(cross_attention_dim) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
        )

    # input
    self.conv_in = nn.Conv2d(
        in_channels,
        block_out_channels[0],
        kernel_size=3,
        padding=1,
    )

    # time
    time_embed_dim = block_out_channels[0] * 4

    self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
    timestep_input_dim = block_out_channels[0]

    self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

    self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
    self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

    self.down_blocks = nn.ModuleList([])
    self.up_blocks = nn.ModuleList([])

    # Broadcast scalar settings to one value per block.
    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * len(down_block_types)

    if isinstance(cross_attention_dim, int):
        cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

    if isinstance(layers_per_block, int):
        layers_per_block = [layers_per_block] * len(down_block_types)

    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

    blocks_time_embed_dim = time_embed_dim

    # down
    output_channel = block_out_channels[0]
    for i, down_block_type in enumerate(down_block_types):
        input_channel = output_channel
        output_channel = block_out_channels[i]
        is_final_block = i == len(block_out_channels) - 1

        down_block = get_down_block(
            down_block_type,
            num_layers=layers_per_block[i],
            transformer_layers_per_block=transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=blocks_time_embed_dim,
            add_downsample=not is_final_block,
            resnet_eps=1e-5,
            cross_attention_dim=cross_attention_dim[i],
            num_attention_heads=num_attention_heads[i],
            resnet_act_fn="silu",
        )
        self.down_blocks.append(down_block)

    # mid
    self.mid_block = UNetMidBlockSpatioTemporal(
        block_out_channels[-1],
        temb_channels=blocks_time_embed_dim,
        transformer_layers_per_block=transformer_layers_per_block[-1],
        cross_attention_dim=cross_attention_dim[-1],
        num_attention_heads=num_attention_heads[-1],
    )

    # count how many layers upsample the images
    self.num_upsamplers = 0

    # up: the per-block settings are consumed in reverse order
    reversed_block_out_channels = list(reversed(block_out_channels))
    reversed_num_attention_heads = list(reversed(num_attention_heads))
    reversed_layers_per_block = list(reversed(layers_per_block))
    reversed_cross_attention_dim = list(reversed(cross_attention_dim))
    reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

    output_channel = reversed_block_out_channels[0]
    for i, up_block_type in enumerate(up_block_types):
        is_final_block = i == len(block_out_channels) - 1

        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

        # add upsample block for all BUT final layer
        if not is_final_block:
            add_upsample = True
            self.num_upsamplers += 1
        else:
            add_upsample = False

        up_block = get_up_block(
            up_block_type,
            num_layers=reversed_layers_per_block[i] + 1,
            transformer_layers_per_block=reversed_transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            prev_output_channel=prev_output_channel,
            temb_channels=blocks_time_embed_dim,
            add_upsample=add_upsample,
            resnet_eps=1e-5,
            resolution_idx=i,
            cross_attention_dim=reversed_cross_attention_dim[i],
            num_attention_heads=reversed_num_attention_heads[i],
            resnet_act_fn="silu",
        )
        self.up_blocks.append(up_block)

    # out
    self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
    self.conv_act = nn.SiLU()

    self.conv_out = nn.Conv2d(
        block_out_channels[0],
        out_channels,
        kernel_size=3,
        padding=1,
    )
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor (the module path followed by `.processor`). This is strongly recommended when setting
            trainable attention processors. The dict passed in is not modified.

    Raises:
        ValueError: If a dict is passed whose size differs from the number of attention layers.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    if isinstance(processor, dict):
        # Entries are popped while walking the module tree; operate on a shallow copy so the caller's
        # dict is left intact and can be reused.
        processor = dict(processor)

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                module.set_processor(processor.pop(f"{name}.processor"))

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
def forward(
    self,
    sample: torch.Tensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    added_time_ids: torch.Tensor,
    return_dict: bool = True,
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
    r"""
    The [`UNetSpatioTemporalConditionModel`] forward method.

    Args:
        sample (`torch.Tensor`):
            The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
        timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.Tensor`):
            The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
        added_time_ids: (`torch.Tensor`):
            The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
            embeddings and added to the time embeddings.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`UNetSpatioTemporalConditionOutput`] instead of a plain tuple.

    Returns:
        [`UNetSpatioTemporalConditionOutput`] or `tuple`:
            If `return_dict` is True, a [`UNetSpatioTemporalConditionOutput`] is returned, otherwise a `tuple`
            is returned where the first element is the sample tensor.
    """
    # 1. time
    timesteps = timestep
    if torch.is_tensor(timesteps):
        if timesteps.ndim == 0:
            # promote a 0-d tensor to a 1-element vector on the sample's device
            timesteps = timesteps[None].to(sample.device)
    else:
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        on_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            scalar_dtype = torch.float32 if on_mps else torch.float64
        else:
            scalar_dtype = torch.int32 if on_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=scalar_dtype, device=sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    batch_size, num_frames = sample.shape[:2]
    timesteps = timesteps.expand(batch_size)

    # `Timesteps` does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16, hence the cast.
    time_features = self.time_proj(timesteps).to(dtype=sample.dtype)
    emb = self.time_embedding(time_features)

    id_features = self.add_time_proj(added_time_ids.flatten()).reshape((batch_size, -1))
    emb = emb + self.add_embedding(id_features.to(emb.dtype))

    # Flatten the batch and frames dimensions
    # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
    sample = sample.flatten(0, 1)
    # Repeat the embeddings num_video_frames times
    # emb: [batch, channels] -> [batch * frames, channels]
    emb = emb.repeat_interleave(num_frames, dim=0)
    # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
    encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)

    # 2. pre-process
    sample = self.conv_in(sample)

    image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)

    # 3. down
    skip_states = (sample,)
    for block in self.down_blocks:
        block_inputs = {
            "hidden_states": sample,
            "temb": emb,
            "image_only_indicator": image_only_indicator,
        }
        if getattr(block, "has_cross_attention", False):
            block_inputs["encoder_hidden_states"] = encoder_hidden_states

        sample, block_skips = block(**block_inputs)
        skip_states += block_skips

    # 4. mid
    sample = self.mid_block(
        hidden_states=sample,
        temb=emb,
        encoder_hidden_states=encoder_hidden_states,
        image_only_indicator=image_only_indicator,
    )

    # 5. up
    for block in self.up_blocks:
        # take one skip connection per resnet of this block off the end of the stack
        n_skips = len(block.resnets)
        block_skips = skip_states[-n_skips:]
        skip_states = skip_states[:-n_skips]

        block_inputs = {
            "hidden_states": sample,
            "temb": emb,
            "res_hidden_states_tuple": block_skips,
            "image_only_indicator": image_only_indicator,
        }
        if getattr(block, "has_cross_attention", False):
            block_inputs["encoder_hidden_states"] = encoder_hidden_states

        sample = block(**block_inputs)

    # 6. post-process
    sample = self.conv_out(self.conv_act(self.conv_norm_out(sample)))

    # 7. Reshape back to original shape
    sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])

    if return_dict:
        return UNetSpatioTemporalConditionOutput(sample=sample)
    return (sample,)
class SDCascadeTimestepBlock(nn.Module):
    """
    Scale-and-shift modulation of a feature map from timestep-style embeddings.

    The embedding `t` is split along dim 1 into `len(conds) + 1` equal chunks. The first chunk goes through
    `mapper`; chunk `i + 1` goes through the linear layer registered as `mapper_<conds[i]>`. Each linear layer
    outputs `2 * c` features that are split into a scale `a` and a shift `b`; the contributions are summed and
    applied as `x * (1 + a) + b`.

    Parameters:
        c (`int`): Number of channels of the feature map being modulated.
        c_timestep (`int`): Width of one embedding chunk (input size of every mapper).
        conds (sequence of `str`, *optional*):
            Names of the additional conditionings; one extra mapper is created per name. `None` (the default)
            means no additional conditioning.
    """

    def __init__(self, c, c_timestep, conds=None):
        super().__init__()
        self.mapper = nn.Linear(c_timestep, c * 2)
        # Build a fresh list per instance: a literal `[]` default would be one shared object that every
        # default-constructed block stores (and could mutate) through `self.conds`.
        self.conds = [] if conds is None else conds
        for cname in self.conds:
            setattr(self, f"mapper_{cname}", nn.Linear(c_timestep, c * 2))

    def forward(self, x, t):
        """
        Modulate `x` with the embedding `t`.

        Args:
            x: Feature map; the scale/shift get two trailing singleton dims, so they broadcast over the last
                two dimensions of `x`.
            t: Embedding whose dim 1 holds `len(self.conds) + 1` chunks.

        Returns:
            `x * (1 + a) + b` with `a`/`b` summed over all mappers.
        """
        t = t.chunk(len(self.conds) + 1, dim=1)
        a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
        for i, c in enumerate(self.conds):
            ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
            a, b = a + ac, b + bc
        return x * (1 + a) + b
# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
class GlobalResponseNorm(nn.Module):
    """
    Global response normalization with a learnable scale (`gamma`) and offset (`beta`).

    Both parameters have shape `(1, 1, 1, dim)` and start at zero, so a freshly constructed layer returns its
    input unchanged. The input is expected to carry the `dim` features in its last dimension; the L2 norm is
    taken over dimensions 1 and 2.
    """

    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        # Per-feature L2 norm aggregated over dims 1 and 2.
        feature_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Divide by the mean norm across features; the epsilon guards against division by zero.
        mean_norm = feature_norm.mean(dim=-1, keepdim=True)
        relative_norm = feature_norm / (mean_norm + 1e-6)
        # Learned rescaling plus the residual connection.
        calibrated = self.gamma * (x * relative_norm)
        return calibrated + self.beta + x
@dataclass
class StableCascadeUNetOutput(BaseOutput):
    """
    Output container with a single tensor field, `sample`.
    """

    # The output tensor; defaults to `None` when the container is built without arguments.
    sample: torch.Tensor = None
block_out_channels (Tuple[int], defaults to (2048, 2048)): Tuple of output channels for each block. num_attention_heads (Tuple[int], defaults to (32, 32)): Number of attention heads in each attention block. Set to -1 to if block types in a layer do not have attention. down_num_layers_per_block (Tuple[int], defaults to [8, 24]): Number of layers in each down block. up_num_layers_per_block (Tuple[int], defaults to [24, 8]): Number of layers in each up block. down_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]): Number of 1x1 Convolutional layers to repeat in each down block. up_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]): Number of 1x1 Convolutional layers to repeat in each up block. block_types_per_layer (Tuple[Tuple[str]], optional, defaults to ( ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock") ): Block types used in each layer of the up/down blocks. clip_text_in_channels (`int`, *optional*, defaults to `None`): Number of input channels for CLIP based text conditioning. clip_text_pooled_in_channels (`int`, *optional*, defaults to 1280): Number of input channels for pooled CLIP text embeddings. clip_image_in_channels (`int`, *optional*): Number of input channels for CLIP based image conditioning. clip_seq (`int`, *optional*, defaults to 4): effnet_in_channels (`int`, *optional*, defaults to `None`): Number of input channels for effnet conditioning. pixel_mapper_in_channels (`int`, defaults to `None`): Number of input channels for pixel mapper conditioning. kernel_size (`int`, *optional*, defaults to 3): Kernel size to use in the block convolutional layers. dropout (Tuple[float], *optional*, defaults to (0.1, 0.1)): Dropout to use per block. self_attn (Union[bool, Tuple[bool]]): Tuple of booleans that determine whether to use self attention in a block or not. 
timestep_conditioning_type (Tuple[str], defaults to ("sca", "crp")): Timestep conditioning type. switch_level (Optional[Tuple[bool]], *optional*, defaults to `None`): Tuple that indicates whether upsampling or downsampling should be applied in a block """ super().__init__() if len(block_out_channels) != len(down_num_layers_per_block): raise ValueError( f"Number of elements in `down_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}" ) elif len(block_out_channels) != len(up_num_layers_per_block): raise ValueError( f"Number of elements in `up_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}" ) elif len(block_out_channels) != len(down_blocks_repeat_mappers): raise ValueError( f"Number of elements in `down_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}" ) elif len(block_out_channels) != len(up_blocks_repeat_mappers): raise ValueError( f"Number of elements in `up_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}" ) elif len(block_out_channels) != len(block_types_per_layer): raise ValueError( f"Number of elements in `block_types_per_layer` must match the length of `block_out_channels`: {len(block_out_channels)}" ) if isinstance(dropout, float): dropout = (dropout,) * len(block_out_channels) if isinstance(self_attn, bool): self_attn = (self_attn,) * len(block_out_channels) # CONDITIONING if effnet_in_channels is not None: self.effnet_mapper = nn.Sequential( nn.Conv2d(effnet_in_channels, block_out_channels[0] * 4, kernel_size=1), nn.GELU(), nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1), SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), ) if pixel_mapper_in_channels is not None: self.pixels_mapper = nn.Sequential( nn.Conv2d(pixel_mapper_in_channels, block_out_channels[0] * 4, kernel_size=1), nn.GELU(), nn.Conv2d(block_out_channels[0] * 
4, block_out_channels[0], kernel_size=1), SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), ) self.clip_txt_pooled_mapper = nn.Linear(clip_text_pooled_in_channels, conditioning_dim * clip_seq) if clip_text_in_channels is not None: self.clip_txt_mapper = nn.Linear(clip_text_in_channels, conditioning_dim) if clip_image_in_channels is not None: self.clip_img_mapper = nn.Linear(clip_image_in_channels, conditioning_dim * clip_seq) self.clip_norm = nn.LayerNorm(conditioning_dim, elementwise_affine=False, eps=1e-6) self.embedding = nn.Sequential( nn.PixelUnshuffle(patch_size), nn.Conv2d(in_channels * (patch_size**2), block_out_channels[0], kernel_size=1), SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), ) def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=True): if block_type == "SDCascadeResBlock": return SDCascadeResBlock(in_channels, c_skip, kernel_size=kernel_size, dropout=dropout) elif block_type == "SDCascadeAttnBlock": return SDCascadeAttnBlock(in_channels, conditioning_dim, nhead, self_attn=self_attn, dropout=dropout) elif block_type == "SDCascadeTimestepBlock": return SDCascadeTimestepBlock( in_channels, timestep_ratio_embedding_dim, conds=timestep_conditioning_type ) else: raise ValueError(f"Block type {block_type} not supported") # BLOCKS # -- down blocks self.down_blocks = nn.ModuleList() self.down_downscalers = nn.ModuleList() self.down_repeat_mappers = nn.ModuleList() for i in range(len(block_out_channels)): if i > 0: self.down_downscalers.append( nn.Sequential( SDCascadeLayerNorm(block_out_channels[i - 1], elementwise_affine=False, eps=1e-6), UpDownBlock2d( block_out_channels[i - 1], block_out_channels[i], mode="down", enabled=switch_level[i - 1] ) if switch_level is not None else nn.Conv2d(block_out_channels[i - 1], block_out_channels[i], kernel_size=2, stride=2), ) ) else: self.down_downscalers.append(nn.Identity()) down_block = nn.ModuleList() for _ in 
range(down_num_layers_per_block[i]): for block_type in block_types_per_layer[i]: block = get_block( block_type, block_out_channels[i], num_attention_heads[i], dropout=dropout[i], self_attn=self_attn[i], ) down_block.append(block) self.down_blocks.append(down_block) if down_blocks_repeat_mappers is not None: block_repeat_mappers = nn.ModuleList() for _ in range(down_blocks_repeat_mappers[i] - 1): block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1)) self.down_repeat_mappers.append(block_repeat_mappers) # -- up blocks self.up_blocks = nn.ModuleList() self.up_upscalers = nn.ModuleList() self.up_repeat_mappers = nn.ModuleList() for i in reversed(range(len(block_out_channels))): if i > 0: self.up_upscalers.append( nn.Sequential( SDCascadeLayerNorm(block_out_channels[i], elementwise_affine=False, eps=1e-6), UpDownBlock2d( block_out_channels[i], block_out_channels[i - 1], mode="up", enabled=switch_level[i - 1] ) if switch_level is not None else nn.ConvTranspose2d( block_out_channels[i], block_out_channels[i - 1], kernel_size=2, stride=2 ), ) ) else: self.up_upscalers.append(nn.Identity()) up_block = nn.ModuleList() for j in range(up_num_layers_per_block[::-1][i]): for k, block_type in enumerate(block_types_per_layer[i]): c_skip = block_out_channels[i] if i < len(block_out_channels) - 1 and j == k == 0 else 0 block = get_block( block_type, block_out_channels[i], num_attention_heads[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i], ) up_block.append(block) self.up_blocks.append(up_block) if up_blocks_repeat_mappers is not None: block_repeat_mappers = nn.ModuleList() for _ in range(up_blocks_repeat_mappers[::-1][i] - 1): block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1)) self.up_repeat_mappers.append(block_repeat_mappers) # OUTPUT self.clf = nn.Sequential( SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), nn.Conv2d(block_out_channels[0], 
out_channels * (patch_size**2), kernel_size=1), nn.PixelShuffle(patch_size), ) self.gradient_checkpointing = False def _set_gradient_checkpointing(self, value=False): self.gradient_checkpointing = value def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): torch.nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0) nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) if hasattr(self, "clip_txt_mapper") else None nn.init.normal_(self.clip_img_mapper.weight, std=0.02) if hasattr(self, "clip_img_mapper") else None if hasattr(self, "effnet_mapper"): nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings if hasattr(self, "pixels_mapper"): nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs nn.init.constant_(self.clf[1].weight, 0) # outputs # blocks for level_block in self.down_blocks + self.up_blocks: for block in level_block: if isinstance(block, SDCascadeResBlock): block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks[0])) elif isinstance(block, SDCascadeTimestepBlock): nn.init.constant_(block.mapper.weight, 0) def get_timestep_ratio_embedding(self, timestep_ratio, max_positions=10000): r = timestep_ratio * max_positions half_dim = self.config.timestep_ratio_embedding_dim // 2 emb = math.log(max_positions) / (half_dim - 1) emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() emb = r[:, None] * emb[None, :] emb = torch.cat([emb.sin(), emb.cos()], dim=1) if self.config.timestep_ratio_embedding_dim % 2 == 1: # zero pad emb = nn.functional.pad(emb, (0, 1), mode="constant") return emb.to(dtype=r.dtype) def get_clip_embeddings(self, clip_txt_pooled, clip_txt=None, clip_img=None): if 
len(clip_txt_pooled.shape) == 2: clip_txt_pool = clip_txt_pooled.unsqueeze(1) clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view( clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.config.clip_seq, -1 ) if clip_txt is not None and clip_img is not None: clip_txt = self.clip_txt_mapper(clip_txt) if len(clip_img.shape) == 2: clip_img = clip_img.unsqueeze(1) clip_img = self.clip_img_mapper(clip_img).view( clip_img.size(0), clip_img.size(1) * self.config.clip_seq, -1 ) clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) else: clip = clip_txt_pool return self.clip_norm(clip) def _down_encode(self, x, r_embed, clip): level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward for down_block, downscaler, repmap in block_group: x = downscaler(x) for i in range(len(repmap) + 1): for block in down_block: if isinstance(block, SDCascadeResBlock): x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) elif isinstance(block, SDCascadeAttnBlock): x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, clip, use_reentrant=False ) elif isinstance(block, SDCascadeTimestepBlock): x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, r_embed, use_reentrant=False ) else: x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), use_reentrant=False) if i < len(repmap): x = repmap[i](x) level_outputs.insert(0, x) else: for down_block, downscaler, repmap in block_group: x = downscaler(x) for i in range(len(repmap) + 1): for block in down_block: if isinstance(block, SDCascadeResBlock): x = block(x) elif isinstance(block, SDCascadeAttnBlock): x = block(x, clip) elif isinstance(block, SDCascadeTimestepBlock): x = block(x, r_embed) else: x = block(x) if i < len(repmap): x = repmap[i](x) 
level_outputs.insert(0, x) return level_outputs def _up_decode(self, level_outputs, r_embed, clip): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward for i, (up_block, upscaler, repmap) in enumerate(block_group): for j in range(len(repmap) + 1): for k, block in enumerate(up_block): if isinstance(block, SDCascadeResBlock): skip = level_outputs[i] if k == 0 and i > 0 else None if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): orig_type = x.dtype x = torch.nn.functional.interpolate( x.float(), skip.shape[-2:], mode="bilinear", align_corners=True ) x = x.to(orig_type) x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, skip, use_reentrant=False ) elif isinstance(block, SDCascadeAttnBlock): x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, clip, use_reentrant=False ) elif isinstance(block, SDCascadeTimestepBlock): x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, r_embed, use_reentrant=False ) else: x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) if j < len(repmap): x = repmap[j](x) x = upscaler(x) else: for i, (up_block, upscaler, repmap) in enumerate(block_group): for j in range(len(repmap) + 1): for k, block in enumerate(up_block): if isinstance(block, SDCascadeResBlock): skip = level_outputs[i] if k == 0 and i > 0 else None if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): orig_type = x.dtype x = torch.nn.functional.interpolate( x.float(), skip.shape[-2:], mode="bilinear", align_corners=True ) x = x.to(orig_type) x = block(x, skip) elif isinstance(block, SDCascadeAttnBlock): x = block(x, clip) elif isinstance(block, SDCascadeTimestepBlock): x = block(x, r_embed) else: x = block(x) if j 
< len(repmap): x = repmap[j](x) x = upscaler(x) return x def forward( self, sample, timestep_ratio, clip_text_pooled, clip_text=None, clip_img=None, effnet=None, pixels=None, sca=None, crp=None, return_dict=True, ): if pixels is None: pixels = sample.new_zeros(sample.size(0), 3, 8, 8) # Process the conditioning embeddings timestep_ratio_embed = self.get_timestep_ratio_embedding(timestep_ratio) for c in self.config.timestep_conditioning_type: if c == "sca": cond = sca elif c == "crp": cond = crp else: cond = None t_cond = cond or torch.zeros_like(timestep_ratio) timestep_ratio_embed = torch.cat([timestep_ratio_embed, self.get_timestep_ratio_embedding(t_cond)], dim=1) clip = self.get_clip_embeddings(clip_txt_pooled=clip_text_pooled, clip_txt=clip_text, clip_img=clip_img) # Model Blocks x = self.embedding(sample) if hasattr(self, "effnet_mapper") and effnet is not None: x = x + self.effnet_mapper( nn.functional.interpolate(effnet, size=x.shape[-2:], mode="bilinear", align_corners=True) ) if hasattr(self, "pixels_mapper"): x = x + nn.functional.interpolate( self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True ) level_outputs = self._down_encode(x, timestep_ratio_embed, clip) x = self._up_decode(level_outputs, timestep_ratio_embed, clip) sample = self.clf(x) if not return_dict: return (sample,) return StableCascadeUNetOutput(sample=sample) ================================================ FILE: src/diffusers/models/unets/uvit_2d.py ================================================ # coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
    @register_to_config
    def __init__(
        self,
        # global config
        hidden_size: int = 1024,
        use_bias: bool = False,
        hidden_dropout: float = 0.0,
        # conditioning dimensions
        cond_embed_dim: int = 768,
        micro_cond_encode_dim: int = 256,
        micro_cond_embed_dim: int = 1280,
        encoder_hidden_size: int = 768,
        # num tokens
        vocab_size: int = 8256,  # codebook_size + 1 (for the mask token) rounded
        codebook_size: int = 8192,
        # `UVit2DConvEmbed`
        in_channels: int = 768,
        block_out_channels: int = 768,
        num_res_blocks: int = 3,
        downsample: bool = False,
        upsample: bool = False,
        block_num_heads: int = 12,
        # `TransformerLayer`
        num_hidden_layers: int = 22,
        num_attention_heads: int = 16,
        # `Attention`
        attention_dropout: float = 0.0,
        # `FeedForward`
        intermediate_size: int = 2816,
        # `Norm`
        layer_norm_eps: float = 1e-6,
        ln_elementwise_affine: bool = True,
        sample_size: int = 64,
    ):
        """
        Build the model's submodules: encoder projection, token embedding, conditioning embedding, a down block,
        `num_hidden_layers` transformer blocks, an up block and the final `mlm_layer`.

        Notes grounded in this constructor:
            - `cond_embed` takes inputs of width `micro_cond_embed_dim + cond_embed_dim`.
            - `downsample` is only applied to the down block and `upsample` only to the up block.
            - Each transformer block uses heads of size `hidden_size // num_attention_heads`.
            - `micro_cond_encode_dim` and `sample_size` are not referenced in this constructor body.
        """
        super().__init__()

        # Projection + norm applied to the encoder hidden states.
        self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias)
        self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)

        self.embed = UVit2DConvEmbed(
            in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
        )

        self.cond_embed = TimestepEmbedding(
            micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias
        )

        # Down block: the trailing positional `False` sits where the up block below passes `upsample=`.
        self.down_block = UVitBlock(
            block_out_channels,
            num_res_blocks,
            hidden_size,
            hidden_dropout,
            ln_elementwise_affine,
            layer_norm_eps,
            use_bias,
            block_num_heads,
            attention_dropout,
            downsample,
            False,
        )

        # Norm + linear going from the block width to the transformer width.
        self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine)
        self.project_to_hidden = nn.Linear(block_out_channels, hidden_size, bias=use_bias)

        self.transformer_layers = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_size,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=hidden_size // num_attention_heads,
                    dropout=hidden_dropout,
                    cross_attention_dim=hidden_size,
                    attention_bias=use_bias,
                    norm_type="ada_norm_continuous",
                    ada_norm_continous_conditioning_embedding_dim=hidden_size,
                    norm_elementwise_affine=ln_elementwise_affine,
                    norm_eps=layer_norm_eps,
                    ada_norm_bias=use_bias,
                    ff_inner_dim=intermediate_size,
                    ff_bias=use_bias,
                    attention_out_bias=use_bias,
                )
                for _ in range(num_hidden_layers)
            ]
        )

        # Norm + linear going from the transformer width back to the block width.
        self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)
        self.project_from_hidden = nn.Linear(hidden_size, block_out_channels, bias=use_bias)

        self.up_block = UVitBlock(
            block_out_channels,
            num_res_blocks,
            hidden_size,
            hidden_dropout,
            ln_elementwise_affine,
            layer_norm_eps,
            use_bias,
            block_num_heads,
            attention_dropout,
            downsample=False,
            upsample=upsample,
        )

        self.mlm_layer = ConvMlmLayer(
            block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size
        )

        # Read by `forward` to decide whether the transformer layers run under checkpointing.
        self.gradient_checkpointing = False
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
    """Run the U-ViT: embed tokens, down block, transformer stack, up block, MLM head.

    Args:
        input_ids: Token ids fed to `self.embed`; `input_ids.shape[0]` is used as the batch size.
        encoder_hidden_states: Text-encoder states; projected by `self.encoder_proj` and normalised
            before being handed to every attention-bearing sub-block.
        pooled_text_emb: Pooled text embedding; concatenated (dim 1) with the micro-conditioning
            embedding and mapped through `self.cond_embed`.
        micro_conds: Micro-conditioning values; flattened, sinusoidally embedded and reshaped to
            `(batch, -1)`.
        cross_attention_kwargs: Optional kwargs forwarded unchanged to the down/up blocks and to each
            transformer layer.

    Returns:
        The output of `self.mlm_layer` applied to the up-block features.
    """
    # Local import so this method does not depend on a file-level import being present.
    from functools import partial

    encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
    encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)

    micro_cond_embeds = get_timestep_embedding(
        micro_conds.flatten(), self.config.micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
    )
    micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1))

    pooled_text_emb = torch.cat([pooled_text_emb, micro_cond_embeds], dim=1)
    pooled_text_emb = pooled_text_emb.to(dtype=self.dtype)
    pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype)

    hidden_states = self.embed(input_ids)

    hidden_states = self.down_block(
        hidden_states,
        pooled_text_emb=pooled_text_emb,
        encoder_hidden_states=encoder_hidden_states,
        cross_attention_kwargs=cross_attention_kwargs,
    )

    # (batch, channels, height, width) -> (batch, height * width, channels) for the transformer stack.
    batch_size, channels, height, width = hidden_states.shape
    hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)

    hidden_states = self.project_to_hidden_norm(hidden_states)
    hidden_states = self.project_to_hidden(hidden_states)

    use_checkpointing = self.training and self.gradient_checkpointing
    for layer in self.transformer_layers:
        # A fresh dict per layer, exactly like the original per-call keyword arguments.
        layer_kwargs = {
            "encoder_hidden_states": encoder_hidden_states,
            "cross_attention_kwargs": cross_attention_kwargs,
            "added_cond_kwargs": {"pooled_text_emb": pooled_text_emb},
        }
        if use_checkpointing:
            # The previous wrapper was `def layer_(*args)` and was then called with keyword
            # arguments, which raises TypeError. Bind the keywords with `partial` (evaluated now,
            # so there is no late-binding on `layer`) and checkpoint only the positional tensor.
            # NOTE(review): assumes `checkpoint` has the `(function, *args)` calling convention of
            # torch.utils.checkpoint.checkpoint -- its import is outside this excerpt; confirm.
            hidden_states = checkpoint(partial(layer, **layer_kwargs), hidden_states)
        else:
            hidden_states = layer(hidden_states, **layer_kwargs)

    hidden_states = self.project_from_hidden_norm(hidden_states)
    hidden_states = self.project_from_hidden(hidden_states)

    # Back to (batch, channels, height, width) for the convolutional up block.
    hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)

    hidden_states = self.up_block(
        hidden_states,
        pooled_text_emb=pooled_text_emb,
        encoder_hidden_states=encoder_hidden_states,
        cross_attention_kwargs=cross_attention_kwargs,
    )

    logits = self.mlm_layer(hidden_states)

    return logits
class UVitBlock(nn.Module):
    """Optional downsample, then paired (ConvNext, attention) stages, then optional upsample.

    The constructor builds `num_res_blocks` `ConvNextBlock`s and the same number of
    `SkipFFTransformerBlock`s; `forward` applies them pairwise. `downsample` / `upsample` select
    whether a `Downsample2D` / `Upsample2D` layer is created (otherwise the attribute is `None`).
    """

    def __init__(
        self,
        channels,
        num_res_blocks: int,
        hidden_size,
        hidden_dropout,
        ln_elementwise_affine,
        layer_norm_eps,
        use_bias,
        block_num_heads,
        attention_dropout,
        downsample: bool,
        upsample: bool,
    ):
        super().__init__()

        # Norm / bias settings shared by both resampling layers.
        resample_norm = {
            "norm_type": "rms_norm",
            "eps": layer_norm_eps,
            "elementwise_affine": ln_elementwise_affine,
            "bias": use_bias,
        }

        self.downsample = (
            Downsample2D(channels, use_conv=True, padding=0, name="Conv2d_0", kernel_size=2, **resample_norm)
            if downsample
            else None
        )

        conv_stages = [
            ConvNextBlock(
                channels,
                layer_norm_eps,
                ln_elementwise_affine,
                use_bias,
                hidden_dropout,
                hidden_size,
            )
            for _ in range(num_res_blocks)
        ]
        self.res_blocks = nn.ModuleList(conv_stages)

        attention_stages = [
            SkipFFTransformerBlock(
                channels,
                block_num_heads,
                channels // block_num_heads,
                hidden_size,
                use_bias,
                attention_dropout,
                channels,
                attention_bias=use_bias,
                attention_out_bias=use_bias,
            )
            for _ in range(num_res_blocks)
        ]
        self.attention_blocks = nn.ModuleList(attention_stages)

        self.upsample = (
            Upsample2D(
                channels,
                use_conv_transpose=True,
                kernel_size=2,
                padding=0,
                name="conv",
                interpolate=False,
                **resample_norm,
            )
            if upsample
            else None
        )

    def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs):
        """Apply the block to a 4-D feature map `x` and return the transformed feature map."""
        if self.downsample is not None:
            x = self.downsample(x)

        for conv_stage, attn_stage in zip(self.res_blocks, self.attention_blocks):
            x = conv_stage(x, pooled_text_emb)

            # Flatten the spatial grid into a token axis for attention, then restore it.
            batch_size, channels, height, width = x.shape
            tokens = x.view(batch_size, channels, height * width).permute(0, 2, 1)
            tokens = attn_stage(
                tokens,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
            )
            x = tokens.permute(0, 2, 1).view(batch_size, channels, height, width)

        if self.upsample is not None:
            x = self.upsample(x)

        return x
class ConvMlmLayer(nn.Module):
    """Prediction head: 1x1 conv -> RMSNorm over channels -> 1x1 conv to `codebook_size` logits."""

    def __init__(
        self,
        block_out_channels: int,
        in_channels: int,
        use_bias: bool,
        ln_elementwise_affine: bool,
        layer_norm_eps: float,
        codebook_size: int,
    ):
        super().__init__()
        # Both convolutions are pointwise (kernel_size=1) and share the bias setting.
        pointwise = {"kernel_size": 1, "bias": use_bias}
        self.conv1 = nn.Conv2d(block_out_channels, in_channels, **pointwise)
        self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine)
        self.conv2 = nn.Conv2d(in_channels, codebook_size, **pointwise)

    def forward(self, hidden_states):
        """Map a 4-D feature map to per-position logits with `codebook_size` channels."""
        projected = self.conv1(hidden_states)
        # The norm is applied with channels as the last axis, so move them there and back.
        channels_last = projected.permute(0, 2, 3, 1)
        normed = self.layer_norm(channels_last).permute(0, 3, 1, 2)
        return self.conv2(normed)
class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer. With `"conv"` the convolution is stored as `self.conv`,
            with any other value it is stored as `self.Conv2d_0` (this decides the state-dict keys).
        kernel_size (`int`, optional):
            convolution kernel size. Defaults to `4` for the transposed convolution and `3` for the
            regular convolution.
        padding (`int`, default `1`):
            padding passed to the convolution.
        norm_type (`str`, optional):
            `"ln_norm"` (LayerNorm), `"rms_norm"` (RMSNorm) or `None` for no normalisation. The norm is
            applied over the channel axis before upsampling.
        eps (`float`, optional):
            epsilon passed to the normalisation layer.
        elementwise_affine (`bool`, optional):
            `elementwise_affine` flag passed to the normalisation layer.
        bias (`bool`, default `True`):
            whether the convolution has a bias.
        interpolate (`bool`, default `True`):
            whether to run nearest-neighbour interpolation on the non-transposed path.

    Raises:
        ValueError: if `norm_type` is not one of the supported values.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size: Optional[int] = None,
        padding=1,
        norm_type=None,
        eps=None,
        elementwise_affine=None,
        bias=True,
        interpolate=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        self.interpolate = interpolate

        if norm_type == "ln_norm":
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(channels, eps, elementwise_affine)
        elif norm_type is None:
            self.norm = None
        else:
            raise ValueError(f"unknown norm_type: {norm_type}")

        conv = None
        if use_conv_transpose:
            if kernel_size is None:
                kernel_size = 4
            conv = nn.ConvTranspose2d(
                channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
            )
        elif use_conv:
            if kernel_size is None:
                kernel_size = 3
            conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def _conv_layer(self):
        """Return the convolution under whichever attribute `__init__` registered it."""
        return self.conv if self.name == "conv" else self.Conv2d_0

    def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor:
        """Upsample `hidden_states` (channels on axis 1).

        Args:
            hidden_states: input tensor; `hidden_states.shape[1]` must equal `self.channels`.
            output_size: if given, forces the interpolation output size instead of `scale_factor=2`.
            *args, **kwargs: accepted only to emit the `scale` deprecation message; otherwise ignored.

        Returns:
            The upsampled (and optionally convolved) tensor.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        assert hidden_states.shape[1] == self.channels

        if self.norm is not None:
            # The norm acts on the last axis, so move channels there and back.
            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        if self.use_conv_transpose:
            # Previously this always used `self.conv`, which does not exist when `name != "conv"`
            # (the layer is then registered as `self.Conv2d_0`) and raised AttributeError.
            return self._conv_layer()(hidden_states)

        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if self.interpolate:
            if output_size is None:
                hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
            else:
                hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            hidden_states = self._conv_layer()(hidden_states)

        return hidden_states
""" def __init__( self, channels: Optional[int] = None, out_channels: Optional[int] = None, use_conv: bool = False, fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1), ): super().__init__() out_channels = out_channels if out_channels else channels if use_conv: self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) self.use_conv = use_conv self.fir_kernel = fir_kernel self.out_channels = out_channels def _upsample_2d( self, hidden_states: torch.Tensor, weight: Optional[torch.Tensor] = None, kernel: Optional[torch.Tensor] = None, factor: int = 2, gain: float = 1, ) -> torch.Tensor: """Fused `upsample_2d()` followed by `Conv2d()`. Padding is performed only once at the beginning, not between the operations. The fused op is considerably more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary order. Args: hidden_states (`torch.Tensor`): Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. weight (`torch.Tensor`, *optional*): Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. kernel (`torch.Tensor`, *optional*): FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. factor (`int`, *optional*): Integer upsampling factor (default: 2). gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0). Returns: output (`torch.Tensor`): Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as `hidden_states`. """ assert isinstance(factor, int) and factor >= 1 # Setup filter kernel. 
class KUpsample2D(nn.Module):
    r"""A 2D K-upsampling layer.

    Pads the input, then runs a stride-2 transposed convolution whose weight places the same fixed
    4x4 kernel on the channel diagonal, so each channel is filtered independently.

    Parameters:
        pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
    """

    def __init__(self, pad_mode: str = "reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = taps.shape[1] // 2 - 1
        # Outer product of the 1D taps; kept out of the state dict (persistent=False).
        self.register_buffer("kernel", taps.T @ taps, persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Upsample `inputs` (channels on axis 1, two trailing spatial axes) by a factor of two."""
        num_channels = inputs.shape[1]
        border = (self.pad + 1) // 2
        padded = F.pad(inputs, (border, border, border, border), self.pad_mode)

        # Zero weight of shape (C, C, kH, kW); only the [c, c] entries receive the kernel.
        kernel_h, kernel_w = self.kernel.shape
        weight = padded.new_zeros([num_channels, num_channels, kernel_h, kernel_w])
        diagonal = torch.arange(num_channels, device=padded.device)
        per_channel_kernel = self.kernel.to(weight)[None, :].expand(num_channels, -1, -1)
        weight[diagonal, diagonal] = per_channel_kernel

        return F.conv_transpose2d(padded, weight, stride=2, padding=self.pad * 2 + 1)
""" def __init__( self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 1, compress_time: bool = False, ) -> None: super().__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) self.compress_time = compress_time def forward(self, inputs: torch.Tensor) -> torch.Tensor: if self.compress_time: if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1: # split first frame x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:] x_first = F.interpolate(x_first, scale_factor=2.0) x_rest = F.interpolate(x_rest, scale_factor=2.0) x_first = x_first[:, :, None, :, :] inputs = torch.cat([x_first, x_rest], dim=2) elif inputs.shape[2] > 1: inputs = F.interpolate(inputs, scale_factor=2.0) else: inputs = inputs.squeeze(2) inputs = F.interpolate(inputs, scale_factor=2.0) inputs = inputs[:, :, None, :, :] else: # only interpolate 2D b, c, t, h, w = inputs.shape inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) inputs = F.interpolate(inputs, scale_factor=2.0) inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4) b, c, t, h, w = inputs.shape inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) inputs = self.conv(inputs) inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4) return inputs def upfirdn2d_native( tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0), ) -> torch.Tensor: up_x = up_y = up down_x = down_y = down pad_x0 = pad_y0 = pad[0] pad_x1 = pad_y1 = pad[1] _, channel, in_h, in_w = tensor.shape tensor = tensor.reshape(-1, in_h, in_w, 1) _, in_h, in_w, minor = tensor.shape kernel_h, kernel_w = kernel.shape out = tensor.view(-1, in_h, 1, in_w, 1, minor) out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) out = 
def upsample_2d(
    hidden_states: torch.Tensor,
    kernel: Optional[torch.Tensor] = None,
    factor: int = 2,
    gain: float = 1,
) -> torch.Tensor:
    r"""Upsample a batch of 2D images with the given FIR filter.

    The filter is normalized to sum to one and then scaled by `gain * factor**2`, so that constant
    input pixels end up scaled by `gain`. Pixels outside the image are treated as zero, and the
    padding is chosen so that the filter footprint is a multiple of the upsampling factor.

    Args:
        hidden_states (`torch.Tensor`):
            Input tensor, handed unchanged to `upfirdn2d_native`.
        kernel (`torch.Tensor`, *optional*):
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable, expanded with an outer
            product). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor (`int`, *optional*, default to `2`):
            Integer upsampling factor; must be an `int` >= 1 (checked with `assert`).
        gain (`float`, *optional*, default to `1.0`):
            Scaling factor for signal magnitude.

    Returns:
        output (`torch.Tensor`):
            The result of `upfirdn2d_native` with `up=factor`.
    """
    assert isinstance(factor, int) and factor >= 1

    taps = [1] * factor if kernel is None else kernel
    fir = torch.tensor(taps, dtype=torch.float32)
    if fir.ndim == 1:
        # Separable filter: build the 2D kernel from the 1D taps.
        fir = torch.outer(fir, fir)
    fir /= torch.sum(fir)
    fir = fir * (gain * (factor**2))

    pad_value = fir.shape[0] - factor
    pad_before = (pad_value + 1) // 2 + factor - 1
    pad_after = pad_value // 2
    return upfirdn2d_native(
        hidden_states,
        fir.to(device=hidden_states.device),
        up=factor,
        pad=(pad_before, pad_after),
    )
class FlaxDownsample2D(nn.Module):
    """
    Flax implementation of 2D Downsample layer

    Pads axes 1 and 2 (the height and width dims) by one element at the end, then applies a
    3x3, stride-2 convolution with `VALID` padding.

    Args:
        in_channels (`int`):
            Input channels; also used as the number of convolution features.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            features=self.in_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding="VALID",
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # One trailing element on the height and width dims; batch and channel dims untouched.
        no_pad, pad_end = (0, 0), (0, 1)
        padded = jnp.pad(hidden_states, pad_width=(no_pad, pad_end, pad_end, no_pad))
        return self.conv(padded)
class FlaxAttentionBlock(nn.Module):
    r"""
    Flax Convolutional based multi-head attention block for diffusion-based VAE.

    Parameters:
        channels (:obj:`int`):
            Input channels
        num_head_channels (:obj:`int`, *optional*, defaults to `None`):
            Number of channels per attention head; the head count is `channels // num_head_channels`.
            When `None`, a single head is used.
        num_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for group norm
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`

    """

    channels: int
    num_head_channels: int = None
    num_groups: int = 32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.num_head_channels is None:
            self.num_heads = 1
        else:
            self.num_heads = self.channels // self.num_head_channels

        make_dense = partial(nn.Dense, self.channels, dtype=self.dtype)

        self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
        self.query = make_dense()
        self.key = make_dense()
        self.value = make_dense()
        self.proj_attn = make_dense()

    def transpose_for_scores(self, projection):
        """Split the last axis into heads: (B, T, H * D) -> (B, H, T, D)."""
        split_shape = projection.shape[:-1] + (self.num_heads, -1)
        per_head = projection.reshape(split_shape)
        return jnp.transpose(per_head, (0, 2, 1, 3))

    def __call__(self, hidden_states):
        skip = hidden_states
        batch, height, width, channels = hidden_states.shape

        # Normalise, then flatten the spatial grid into a token axis.
        tokens = self.group_norm(hidden_states)
        tokens = tokens.reshape((batch, height * width, channels))

        query = self.transpose_for_scores(self.query(tokens))
        key = self.transpose_for_scores(self.key(tokens))
        value = self.transpose_for_scores(self.value(tokens))

        # The scale is applied to query and key separately, hence the fourth root of the head width.
        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
        scores = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)
        attn_weights = nn.softmax(scores, axis=-1)

        # Weighted sum of values, then merge the heads back: (B, H, T, D) -> (B, T, H * D).
        context = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
        context = jnp.transpose(context, (0, 2, 1, 3))
        merged_shape = context.shape[:-2] + (self.channels,)
        context = context.reshape(merged_shape)

        context = self.proj_attn(context)
        context = context.reshape((batch, height, width, channels))
        return context + skip
Parameters: in_channels (:obj:`int`): Input channels out_channels (:obj:`int`): Output channels dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block resnet_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for the Resnet block group norm add_upsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add upsample layer dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int out_channels: int dropout: float = 0.0 num_layers: int = 1 resnet_groups: int = 32 add_upsample: bool = True dtype: jnp.dtype = jnp.float32 def setup(self): resnets = [] for i in range(self.num_layers): in_channels = self.in_channels if i == 0 else self.out_channels res_block = FlaxResnetBlock2D( in_channels=in_channels, out_channels=self.out_channels, dropout=self.dropout, groups=self.resnet_groups, dtype=self.dtype, ) resnets.append(res_block) self.resnets = resnets if self.add_upsample: self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True): for resnet in self.resnets: hidden_states = resnet(hidden_states, deterministic=deterministic) if self.add_upsample: hidden_states = self.upsamplers_0(hidden_states) return hidden_states class FlaxUNetMidBlock2D(nn.Module): r""" Flax Unet Mid-Block module. 
Parameters: in_channels (:obj:`int`): Input channels dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block resnet_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for the Resnet and Attention block group norm num_attention_heads (:obj:`int`, *optional*, defaults to `1`): Number of attention heads for each attention block dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int dropout: float = 0.0 num_layers: int = 1 resnet_groups: int = 32 num_attention_heads: int = 1 dtype: jnp.dtype = jnp.float32 def setup(self): resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32) # there is always at least one resnet resnets = [ FlaxResnetBlock2D( in_channels=self.in_channels, out_channels=self.in_channels, dropout=self.dropout, groups=resnet_groups, dtype=self.dtype, ) ] attentions = [] for _ in range(self.num_layers): attn_block = FlaxAttentionBlock( channels=self.in_channels, num_head_channels=self.num_attention_heads, num_groups=resnet_groups, dtype=self.dtype, ) attentions.append(attn_block) res_block = FlaxResnetBlock2D( in_channels=self.in_channels, out_channels=self.in_channels, dropout=self.dropout, groups=resnet_groups, dtype=self.dtype, ) resnets.append(res_block) self.resnets = resnets self.attentions = attentions def __call__(self, hidden_states, deterministic=True): hidden_states = self.resnets[0](hidden_states, deterministic=deterministic) for attn, resnet in zip(self.attentions, self.resnets[1:]): hidden_states = attn(hidden_states) hidden_states = resnet(hidden_states, deterministic=deterministic) return hidden_states class FlaxEncoder(nn.Module): r""" Flax Implementation of VAE Encoder. This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) subclass. 
Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and behavior. Finally, this model supports inherent JAX features such as: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: in_channels (:obj:`int`, *optional*, defaults to 3): Input channels out_channels (:obj:`int`, *optional*, defaults to 3): Output channels down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): DownEncoder block type block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): Tuple containing the number of output channels for each block layers_per_block (:obj:`int`, *optional*, defaults to `2`): Number of Resnet layer for each block norm_num_groups (:obj:`int`, *optional*, defaults to `32`): norm num group act_fn (:obj:`str`, *optional*, defaults to `silu`): Activation function double_z (:obj:`bool`, *optional*, defaults to `False`): Whether to double the last output channels dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int = 3 out_channels: int = 3 down_block_types: Tuple[str] = ("DownEncoderBlock2D",) block_out_channels: Tuple[int] = (64,) layers_per_block: int = 2 norm_num_groups: int = 32 act_fn: str = "silu" double_z: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): block_out_channels = self.block_out_channels # in self.conv_in = nn.Conv( block_out_channels[0], kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) # downsampling down_blocks = [] output_channel = block_out_channels[0] for i, _ in enumerate(self.down_block_types): input_channel = 
output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = FlaxDownEncoderBlock2D( in_channels=input_channel, out_channels=output_channel, num_layers=self.layers_per_block, resnet_groups=self.norm_num_groups, add_downsample=not is_final_block, dtype=self.dtype, ) down_blocks.append(down_block) self.down_blocks = down_blocks # middle self.mid_block = FlaxUNetMidBlock2D( in_channels=block_out_channels[-1], resnet_groups=self.norm_num_groups, num_attention_heads=None, dtype=self.dtype, ) # end conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) self.conv_out = nn.Conv( conv_out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) def __call__(self, sample, deterministic: bool = True): # in sample = self.conv_in(sample) # downsampling for block in self.down_blocks: sample = block(sample, deterministic=deterministic) # middle sample = self.mid_block(sample, deterministic=deterministic) # end sample = self.conv_norm_out(sample) sample = nn.swish(sample) sample = self.conv_out(sample) return sample class FlaxDecoder(nn.Module): r""" Flax Implementation of VAE Decoder. This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and behavior. 
Finally, this model supports inherent JAX features such as: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: in_channels (:obj:`int`, *optional*, defaults to 3): Input channels out_channels (:obj:`int`, *optional*, defaults to 3): Output channels up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): UpDecoder block type block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): Tuple containing the number of output channels for each block layers_per_block (:obj:`int`, *optional*, defaults to `2`): Number of Resnet layer for each block norm_num_groups (:obj:`int`, *optional*, defaults to `32`): norm num group act_fn (:obj:`str`, *optional*, defaults to `silu`): Activation function double_z (:obj:`bool`, *optional*, defaults to `False`): Whether to double the last output channels dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` """ in_channels: int = 3 out_channels: int = 3 up_block_types: Tuple[str] = ("UpDecoderBlock2D",) block_out_channels: int = (64,) layers_per_block: int = 2 norm_num_groups: int = 32 act_fn: str = "silu" dtype: jnp.dtype = jnp.float32 def setup(self): block_out_channels = self.block_out_channels # z to block_in self.conv_in = nn.Conv( block_out_channels[-1], kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) # middle self.mid_block = FlaxUNetMidBlock2D( in_channels=block_out_channels[-1], resnet_groups=self.norm_num_groups, num_attention_heads=None, dtype=self.dtype, ) # upsampling reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] 
up_blocks = [] for i, _ in enumerate(self.up_block_types): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 up_block = FlaxUpDecoderBlock2D( in_channels=prev_output_channel, out_channels=output_channel, num_layers=self.layers_per_block + 1, resnet_groups=self.norm_num_groups, add_upsample=not is_final_block, dtype=self.dtype, ) up_blocks.append(up_block) prev_output_channel = output_channel self.up_blocks = up_blocks # end self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) self.conv_out = nn.Conv( self.out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) def __call__(self, sample, deterministic: bool = True): # z to block_in sample = self.conv_in(sample) # middle sample = self.mid_block(sample, deterministic=deterministic) # upsampling for block in self.up_blocks: sample = block(sample, deterministic=deterministic) sample = self.conv_norm_out(sample) sample = nn.swish(sample) sample = self.conv_out(sample) return sample class FlaxDiagonalGaussianDistribution(object): def __init__(self, parameters, deterministic=False): # Last axis to account for channels-last self.mean, self.logvar = jnp.split(parameters, 2, axis=-1) self.logvar = jnp.clip(self.logvar, -30.0, 20.0) self.deterministic = deterministic self.std = jnp.exp(0.5 * self.logvar) self.var = jnp.exp(self.logvar) if self.deterministic: self.var = self.std = jnp.zeros_like(self.mean) def sample(self, key): return self.mean + self.std * jax.random.normal(key, self.mean.shape) def kl(self, other=None): if self.deterministic: return jnp.array([0.0]) if other is None: return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3]) return 0.5 * jnp.sum( jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, axis=[1, 2, 3], ) def nll(self, sample, axis=[1, 2, 3]): if self.deterministic: return 
jnp.array([0.0]) logtwopi = jnp.log(2.0 * jnp.pi) return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis) def mode(self): return self.mean @flax_register_to_config class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): r""" Flax implementation of a VAE model with KL loss for decoding latent representations. This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matter related to its general usage and behavior. Inherent JAX features such as the following are supported: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): Tuple of downsample block types. up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): Tuple of upsample block types. block_out_channels (`Tuple[str]`, *optional*, defaults to `(64,)`): Tuple of block output channels. layers_per_block (`int`, *optional*, defaults to `2`): Number of ResNet layer for each block. act_fn (`str`, *optional*, defaults to `silu`): The activation function to use. latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. 
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. sample_size (`int`, *optional*, defaults to 32): Sample input size. scaling_factor (`float`, *optional*, defaults to 0.18215): The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): The `dtype` of the parameters. """ in_channels: int = 3 out_channels: int = 3 down_block_types: Tuple[str] = ("DownEncoderBlock2D",) up_block_types: Tuple[str] = ("UpDecoderBlock2D",) block_out_channels: Tuple[int] = (64,) layers_per_block: int = 1 act_fn: str = "silu" latent_channels: int = 4 norm_num_groups: int = 32 sample_size: int = 32 scaling_factor: float = 0.18215 dtype: jnp.dtype = jnp.float32 def setup(self): self.encoder = FlaxEncoder( in_channels=self.config.in_channels, out_channels=self.config.latent_channels, down_block_types=self.config.down_block_types, block_out_channels=self.config.block_out_channels, layers_per_block=self.config.layers_per_block, act_fn=self.config.act_fn, norm_num_groups=self.config.norm_num_groups, double_z=True, dtype=self.dtype, ) self.decoder = FlaxDecoder( in_channels=self.config.latent_channels, out_channels=self.config.out_channels, up_block_types=self.config.up_block_types, block_out_channels=self.config.block_out_channels, layers_per_block=self.config.layers_per_block, norm_num_groups=self.config.norm_num_groups, act_fn=self.config.act_fn, dtype=self.dtype, 
) self.quant_conv = nn.Conv( 2 * self.config.latent_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) self.post_quant_conv = nn.Conv( self.config.latent_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) def init_weights(self, rng: jax.Array) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3) rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng} return self.init(rngs, sample)["params"] def encode(self, sample, deterministic: bool = True, return_dict: bool = True): sample = jnp.transpose(sample, (0, 2, 3, 1)) hidden_states = self.encoder(sample, deterministic=deterministic) moments = self.quant_conv(hidden_states) posterior = FlaxDiagonalGaussianDistribution(moments) if not return_dict: return (posterior,) return FlaxAutoencoderKLOutput(latent_dist=posterior) def decode(self, latents, deterministic: bool = True, return_dict: bool = True): if latents.shape[-1] != self.config.latent_channels: latents = jnp.transpose(latents, (0, 2, 3, 1)) hidden_states = self.post_quant_conv(latents) hidden_states = self.decoder(hidden_states, deterministic=deterministic) hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) if not return_dict: return (hidden_states,) return FlaxDecoderOutput(sample=hidden_states) def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True): posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict) if sample_posterior: rng = self.make_rng("gaussian") hidden_states = posterior.latent_dist.sample(rng) else: hidden_states = posterior.latent_dist.mode() sample = self.decode(hidden_states, return_dict=return_dict).sample if not return_dict: return (sample,) return FlaxDecoderOutput(sample=sample) 
class VQEncoderOutput(VQEncoderOutput):
    # Backward-compatibility shim: behaves like the class it subclasses, but reports a
    # deprecation every time an instance is created through this legacy import path.
    def __init__(self, *args, **kwargs):
        deprecate(
            "VQEncoderOutput",
            "0.31",
            "Importing `VQEncoderOutput` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQEncoderOutput`, instead.",
        )
        super().__init__(*args, **kwargs)


class VQModel(VQModel):
    # Backward-compatibility shim: behaves like the class it subclasses, but reports a
    # deprecation every time an instance is created through this legacy import path.
    def __init__(self, *args, **kwargs):
        deprecate(
            "VQModel",
            "0.31",
            "Importing `VQModel` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQModel`, instead.",
        )
        super().__init__(*args, **kwargs)
class SchedulerType(Enum):
    """Names of the learning-rate schedules that `get_scheduler` can build."""

    LINEAR = "linear"
    COSINE = "cosine"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    POLYNOMIAL = "polynomial"
    CONSTANT = "constant"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    PIECEWISE_CONSTANT = "piecewise_constant"


def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate, using the learning rate set in optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)


def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
    increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1.0, num_warmup_steps))
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule whose learning rate multiplier is piecewise constant, as described by `step_rules`.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        step_rules (`string`):
            Comma-separated rules of the form `multiplier:step`, followed by one final bare multiplier. Each rule
            applies while the current step is below its `step` boundary. For example,
            `step_rules="1:10,0.1:20,0.01:30,0.005"` multiplies the learning rate by 1 before step 10, by 0.1
            before step 20, by 0.01 before step 30, and by 0.005 from step 30 onwards.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    *interval_rules, final_rule = step_rules.split(",")

    rules_dict = {}
    for rule_str in interval_rules:
        value_str, steps_str = rule_str.split(":")
        steps = int(steps_str)
        value = float(value_str)
        # a boundary given twice keeps the multiplier of its last occurrence
        rules_dict[steps] = value
    last_lr_multiple = float(final_rule)

    # Sort the boundaries once here rather than on every scheduler step.
    sorted_rules = sorted(rules_dict.items())

    def rule_func(steps: int) -> float:
        for boundary, multiple in sorted_rules:
            if steps < boundary:
                return multiple
        return last_lr_multiple

    return LambdaLR(optimizer, rule_func, last_epoch=last_epoch)


def get_linear_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of periods of the cosine function in a schedule (the default is to just decrease from the max
            value to 0 following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
    linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.0:
            return 0.0
        # the modulo restarts the half-cosine at the beginning of every cycle
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_polynomial_decay_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    lr_end: float = 1e-7,
    power: float = 1.0,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        ValueError: If `lr_end` is not smaller than the optimizer's default learning rate.
    """

    lr_init = optimizer.defaults["lr"]
    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        elif current_step > num_training_steps:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        else:
            lr_range = lr_init - lr_end
            # Clamp like the sibling schedules do, so that num_training_steps == num_warmup_steps
            # does not divide by zero.
            decay_steps = max(1, num_training_steps - num_warmup_steps)
            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
            decay = lr_range * pct_remaining**power + lr_end
            return decay / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch)


TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
}


def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    step_rules: Optional[str] = None,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        step_rules (`str`, *optional*):
            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler,
            which requires it.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. It is not forwarded to any other
            scheduler.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See `POLYNOMIAL` scheduler
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        ValueError: If `name` is not a valid `SchedulerType`, or if an argument required by the selected scheduler
            is left unset.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer, last_epoch=last_epoch)

    if name == SchedulerType.PIECEWISE_CONSTANT:
        # Fail with the same kind of error as the other missing-argument checks instead of
        # an AttributeError from `None.split`.
        if step_rules is None:
            raise ValueError(f"{name} requires `step_rules`, please provide that argument.")
        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=num_cycles,
            last_epoch=last_epoch,
        )

    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            power=power,
            last_epoch=last_epoch,
        )

    return schedule_func(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch
    )
================================================ FILE: src/diffusers/pipelines/README.md ================================================ # 🧨 Diffusers Pipelines Pipelines provide a simple way to run state-of-the-art diffusion models in inference. Most diffusion systems consist of multiple independently-trained models and highly adaptable scheduler components - all of which are needed to have a functioning end-to-end diffusion system. As an example, [Stable Diffusion](https://huggingface.co/blog/stable_diffusion) has three independently trained models: - [Autoencoder](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/vae.py#L392) - [Conditional Unet](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/unet_2d_condition.py#L12) - [CLIP text encoder](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTextModel) - a scheduler component, [scheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py), - a [CLIPImageProcessor](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPImageProcessor), - as well as a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py). All of these components are necessary to run stable diffusion in inference even though they were trained or created independently from each other. To that end, we strive to offer all open-sourced, state-of-the-art diffusion systems under a unified API. More specifically, we strive to provide pipelines that - 1.
can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)), - 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section), - 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)), - 4. can easily be contributed by the community (see the [Contribution](#contribution) section). **Note** that pipelines do not (and should not) offer any training functionality. If you are looking for *official* training examples, please have a look at [examples](https://github.com/huggingface/diffusers/tree/main/examples). ## Pipelines Summary The following table summarizes all officially supported pipelines, their corresponding paper, and if available a colab notebook to directly try them out. 
| Pipeline | Source | Tasks | Colab |-------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|:---:|:---:| | [dance diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/Harmonai-org/sample-generator) | *Unconditional Audio Generation* | | [ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | *Unconditional Image Generation* | | [ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | *Unconditional Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) | [latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Text-to-Image Generation* | | [latent_diffusion_uncond](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Unconditional Image Generation* | | [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* | | [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based 
Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | | [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | | [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) | [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* | **Note**: 
Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. However, most of them can be adapted to use different scheduler components or even different model components. Some pipeline examples are shown in the [Examples](#examples) below. ## Pipelines API Diffusion models often consist of multiple independently-trained models or other previously existing components. Each model has been trained independently on a different task and the scheduler can easily be swapped out and replaced with a different one. During inference, we however want to be able to easily load all components and use them in inference - even if one component, *e.g.* CLIP's text encoder, originates from a different library, such as [Transformers](https://github.com/huggingface/transformers). To that end, all pipelines provide the following functionality: - [`from_pretrained` method](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L139) that accepts a Hugging Face Hub repository id, *e.g.* [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) or a path to a local directory, *e.g.* "./stable-diffusion". To correctly retrieve which models and components should be loaded, one has to provide a `model_index.json` file, *e.g.* [runwayml/stable-diffusion-v1-5/model_index.json](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json), which defines all components that should be loaded into the pipelines. More specifically, for each model/component one needs to define the format `<name>: ["<library>", "<class name>"]`. `<name>` is the attribute name given to the loaded instance of `<class name>` which can be found in the library or pipeline folder called `"<library>"`.
- [`save_pretrained`](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L90) that accepts a local path, *e.g.* `./stable-diffusion` under which all models/components of the pipeline will be saved. For each component/model a folder is created inside the local path that is named after the given attribute name, *e.g.* `./stable_diffusion/unet`. In addition, a `model_index.json` file is created at the root of the local path, *e.g.* `./stable_diffusion/model_index.json` so that the complete pipeline can again be instantiated from the local path. - [`to`](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L118) which accepts a `string` or `torch.device` to move all models that are of type `torch.nn.Module` to the passed device. The behavior is fully analogous to [PyTorch's `to` method](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to). - [`__call__`] method to use the pipeline in inference. `__call__` defines inference logic of the pipeline and should ideally encompass all aspects of it, from pre-processing to forwarding tensors to the different models and schedulers, as well as post-processing. The API of the `__call__` method can strongly vary from pipeline to pipeline. *E.g.* a text-to-image pipeline, such as [`StableDiffusionPipeline`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) should accept among other things the text prompt to generate the image. A pure image generation pipeline, such as [DDPMPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/ddpm) on the other hand can be run without providing any inputs. To better understand what inputs can be adapted for each pipeline, one should look directly into the respective pipeline. 
**Note**: All pipelines have PyTorch's autograd disabled by decorating the `__call__` method with a [`torch.no_grad`](https://pytorch.org/docs/stable/generated/torch.no_grad.html) decorator because pipelines should not be used for training. If you want to store the gradients during the forward pass, we recommend writing your own pipeline, see also our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community) ## Contribution We are more than happy about any contribution to the officially supported pipelines 🤗. We aspire all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**. - **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file itself, should be inherited from (and only from) the [`DiffusionPipeline` class](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L56) or be directly attached to the model and scheduler components of the pipeline. - **Easy-to-use**: Pipelines should be extremely easy to use - one should be able to load the pipeline and use it for its designated task, *e.g.* text-to-image generation, in just a couple of lines of code. Most logic including pre-processing, an unrolled diffusion loop, and post-processing should all happen inside the `__call__` method. - **Easy-to-tweak**: Certain pipelines will not be able to handle all use cases and tasks that you might like them to. If you want to use a certain pipeline for a specific use case that is not yet supported, you might have to copy the pipeline file and tweak the code to your needs. We try to make the pipeline code as readable as possible so that each part –from pre-processing to diffusing to post-processing– can easily be adapted. 
If you would like the community to benefit from your customized pipeline, we would love to see a contribution to our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community). If you feel that an important pipeline should be part of the official pipelines but isn't, a contribution to the [official pipelines](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines) would be even better. - **One-purpose-only**: Pipelines should be used for one task and one task only. Even if two tasks are very similar from a modeling point of view, *e.g.* image2image translation and in-painting, pipelines shall be used for one task only to keep them *easy-to-tweak* and *readable*. ## Examples ### Text-to-Image generation with Stable Diffusion ```python # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` ### Image-to-Image text-guided generation with Stable Diffusion The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images. 
```python import requests import torch from PIL import Image from io import BytesIO from diffusers import StableDiffusionImg2ImgPipeline # load the pipeline device = "cuda" pipe = StableDiffusionImg2ImgPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, ).to(device) # let's download an initial image url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" response = requests.get(url) init_image = Image.open(BytesIO(response.content)).convert("RGB") init_image = init_image.resize((768, 512)) prompt = "A fantasy landscape, trending on artstation" images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images images[0].save("fantasy_landscape.png") ``` You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) ### Tweak prompts reusing seeds and latents You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](https://github.com/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb). ### In-painting using Stable Diffusion The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by providing a mask and text prompt.
```python import PIL import requests import torch from io import BytesIO from diffusers import StableDiffusionInpaintPipeline def download_image(url): response = requests.get(url) return PIL.Image.open(BytesIO(response.content)).convert("RGB") img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" init_image = download_image(img_url).resize((512, 512)) mask_image = download_image(mask_url).resize((512, 512)) pipe = StableDiffusionInpaintPipeline.from_pretrained( "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, ) pipe = pipe.to("cuda") prompt = "Face of a yellow cat, high resolution, sitting on a park bench" image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] ``` You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) ================================================ FILE: src/diffusers/pipelines/__init__.py ================================================ from typing import TYPE_CHECKING from ..utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_flax_available, is_k_diffusion_available, is_librosa_available, is_note_seq_available, is_onnx_available, is_sentencepiece_available, is_torch_available, is_torch_npu_available, is_transformers_available, ) # These modules contain pipelines from multiple libraries/frameworks _dummy_objects = {} _import_structure = { "controlnet": [], "controlnet_hunyuandit": [], "controlnet_sd3": [], "controlnet_xs": [], "deprecated": [], "latent_diffusion": [], "ledits_pp": [], "marigold": [], "pag": [], "stable_diffusion": 
[], "stable_diffusion_xl": [], } try: if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import dummy_pt_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) else: _import_structure["auto_pipeline"] = [ "AutoPipelineForImage2Image", "AutoPipelineForInpainting", "AutoPipelineForText2Image", ] _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] _import_structure["ddim"] = ["DDIMPipeline"] _import_structure["ddpm"] = ["DDPMPipeline"] _import_structure["dit"] = ["DiTPipeline"] _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"]) _import_structure["pipeline_utils"] = [ "AudioPipelineOutput", "DiffusionPipeline", "StableDiffusionMixin", "ImagePipelineOutput", ] _import_structure["deprecated"].extend( [ "PNDMPipeline", "LDMPipeline", "RePaintPipeline", "ScoreSdeVePipeline", "KarrasVePipeline", ] ) try: if not (is_torch_available() and is_librosa_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import dummy_torch_and_librosa_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects)) else: _import_structure["deprecated"].extend(["AudioDiffusionPipeline", "Mel"]) try: if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) else: _import_structure["deprecated"].extend( [ "MidiProcessor", "SpectrogramDiffusionPipeline", ] ) try: if not (is_torch_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import 
dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["deprecated"].extend( [ "VQDiffusionPipeline", "AltDiffusionPipeline", "AltDiffusionImg2ImgPipeline", "CycleDiffusionPipeline", "StableDiffusionInpaintPipelineLegacy", "StableDiffusionPix2PixZeroPipeline", "StableDiffusionParadigmsPipeline", "StableDiffusionModelEditingPipeline", "VersatileDiffusionDualGuidedPipeline", "VersatileDiffusionImageVariationPipeline", "VersatileDiffusionPipeline", "VersatileDiffusionTextToImagePipeline", ] ) _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"] _import_structure["animatediff"] = [ "AnimateDiffPipeline", "AnimateDiffControlNetPipeline", "AnimateDiffSDXLPipeline", "AnimateDiffSparseControlNetPipeline", "AnimateDiffVideoToVideoPipeline", ] _import_structure["flux"] = ["FluxPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", "AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel", ] _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"] _import_structure["cogvideo"] = ["CogVideoXPipeline"] _import_structure["controlnet"].extend( [ "BlipDiffusionControlNetPipeline", "StableDiffusionControlNetImg2ImgPipeline", "StableDiffusionControlNetInpaintPipeline", "StableDiffusionControlNetPipeline", "StableDiffusionXLControlNetImg2ImgPipeline", "StableDiffusionXLControlNetInpaintPipeline", "StableDiffusionXLControlNetPipeline", ] ) _import_structure["pag"].extend( [ "AnimateDiffPAGPipeline", "KolorsPAGPipeline", "HunyuanDiTPAGPipeline", "StableDiffusion3PAGPipeline", "StableDiffusionPAGPipeline", "StableDiffusionControlNetPAGPipeline", "StableDiffusionXLPAGPipeline", "StableDiffusionXLPAGInpaintPipeline", "StableDiffusionXLControlNetPAGPipeline", "StableDiffusionXLPAGImg2ImgPipeline", "PixArtSigmaPAGPipeline", ] ) _import_structure["controlnet_xs"].extend( [ 
"StableDiffusionControlNetXSPipeline", "StableDiffusionXLControlNetXSPipeline", ] ) _import_structure["controlnet_hunyuandit"].extend( [ "HunyuanDiTControlNetPipeline", ] ) _import_structure["controlnet_sd3"].extend( [ "StableDiffusion3ControlNetPipeline", ] ) _import_structure["deepfloyd_if"] = [ "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", "IFInpaintingPipeline", "IFInpaintingSuperResolutionPipeline", "IFPipeline", "IFSuperResolutionPipeline", ] _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] _import_structure["kandinsky"] = [ "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", "KandinskyImg2ImgPipeline", "KandinskyInpaintCombinedPipeline", "KandinskyInpaintPipeline", "KandinskyPipeline", "KandinskyPriorPipeline", ] _import_structure["kandinsky2_2"] = [ "KandinskyV22CombinedPipeline", "KandinskyV22ControlnetImg2ImgPipeline", "KandinskyV22ControlnetPipeline", "KandinskyV22Img2ImgCombinedPipeline", "KandinskyV22Img2ImgPipeline", "KandinskyV22InpaintCombinedPipeline", "KandinskyV22InpaintPipeline", "KandinskyV22Pipeline", "KandinskyV22PriorEmb2EmbPipeline", "KandinskyV22PriorPipeline", ] _import_structure["kandinsky3"] = [ "Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline", ] _import_structure["latent_consistency_models"] = [ "LatentConsistencyModelImg2ImgPipeline", "LatentConsistencyModelPipeline", ] _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"]) _import_structure["ledits_pp"].extend( [ "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", ] ) _import_structure["latte"] = ["LattePipeline"] _import_structure["lumina"] = ["LuminaText2ImgPipeline"] _import_structure["marigold"].extend( [ "MarigoldDepthPipeline", "MarigoldNormalsPipeline", ] ) _import_structure["musicldm"] = ["MusicLDMPipeline"] _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] _import_structure["pia"] = ["PIAPipeline"] _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"] 
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] _import_structure["stable_audio"] = [ "StableAudioProjectionModel", "StableAudioPipeline", ] _import_structure["stable_cascade"] = [ "StableCascadeCombinedPipeline", "StableCascadeDecoderPipeline", "StableCascadePriorPipeline", ] _import_structure["stable_diffusion"].extend( [ "CLIPImageProjection", "StableDiffusionDepth2ImgPipeline", "StableDiffusionImageVariationPipeline", "StableDiffusionImg2ImgPipeline", "StableDiffusionInpaintPipeline", "StableDiffusionInstructPix2PixPipeline", "StableDiffusionLatentUpscalePipeline", "StableDiffusionPipeline", "StableDiffusionUpscalePipeline", "StableUnCLIPImg2ImgPipeline", "StableUnCLIPPipeline", "StableDiffusionLDM3DPipeline", ] ) _import_structure["aura_flow"] = ["AuraFlowPipeline"] _import_structure["stable_diffusion_3"] = [ "StableDiffusion3Pipeline", "StableDiffusion3Img2ImgPipeline", "StableDiffusion3InpaintPipeline", ] _import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] _import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"] _import_structure["stable_diffusion_gligen"] = [ "StableDiffusionGLIGENPipeline", "StableDiffusionGLIGENTextImagePipeline", ] _import_structure["stable_video_diffusion"] = ["StableVideoDiffusionPipeline"] _import_structure["stable_diffusion_xl"].extend( [ "StableDiffusionXLImg2ImgPipeline", "StableDiffusionXLInpaintPipeline", "StableDiffusionXLInstructPix2PixPipeline", "StableDiffusionXLPipeline", ] ) _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] _import_structure["stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"] _import_structure["stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"] _import_structure["t2i_adapter"] = [ 
"StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline", ] _import_structure["text_to_video_synthesis"] = [ "TextToVideoSDPipeline", "TextToVideoZeroPipeline", "TextToVideoZeroSDXLPipeline", "VideoToVideoSDPipeline", ] _import_structure["i2vgen_xl"] = ["I2VGenXLPipeline"] _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"] _import_structure["unidiffuser"] = [ "ImageTextPipelineOutput", "UniDiffuserModel", "UniDiffuserPipeline", "UniDiffuserTextDecoder", ] _import_structure["wuerstchen"] = [ "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", ] try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import dummy_onnx_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_onnx_objects)) else: _import_structure["onnx_utils"] = ["OnnxRuntimeModel"] try: if not (is_torch_available() and is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) else: _import_structure["stable_diffusion"].extend( [ "OnnxStableDiffusionImg2ImgPipeline", "OnnxStableDiffusionInpaintPipeline", "OnnxStableDiffusionPipeline", "OnnxStableDiffusionUpscalePipeline", "StableDiffusionOnnxPipeline", ] ) try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import ( dummy_torch_and_transformers_and_k_diffusion_objects, ) _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) else: _import_structure["stable_diffusion_k_diffusion"] = [ "StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline", ] try: if not 
(is_torch_available() and is_transformers_available() and is_sentencepiece_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import ( dummy_torch_and_transformers_and_sentencepiece_objects, ) _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects)) else: _import_structure["kolors"] = [ "KolorsPipeline", "KolorsImg2ImgPipeline", ] try: if not is_flax_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import dummy_flax_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_flax_objects)) else: _import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"] try: if not (is_flax_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import dummy_flax_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) else: _import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"]) _import_structure["stable_diffusion"].extend( [ "FlaxStableDiffusionImg2ImgPipeline", "FlaxStableDiffusionInpaintPipeline", "FlaxStableDiffusionPipeline", ] ) _import_structure["stable_diffusion_xl"].extend( [ "FlaxStableDiffusionXLPipeline", ] ) if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: from .auto_pipeline import ( AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image, ) from .consistency_models import ConsistencyModelPipeline from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline 
from .latent_diffusion import LDMSuperResolutionPipeline from .pipeline_utils import ( AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin, ) try: if not (is_torch_available() and is_librosa_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_librosa_objects import * else: from .deprecated import AudioDiffusionPipeline, Mel try: if not (is_torch_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_objects import * else: from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline from .animatediff import ( AnimateDiffControlNetPipeline, AnimateDiffPipeline, AnimateDiffSDXLPipeline, AnimateDiffSparseControlNetPipeline, AnimateDiffVideoToVideoPipeline, ) from .audioldm import AudioLDMPipeline from .audioldm2 import ( AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, ) from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline from .cogvideo import CogVideoXPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, ) from .controlnet_hunyuandit import ( HunyuanDiTControlNetPipeline, ) from .controlnet_sd3 import ( StableDiffusion3ControlNetPipeline, ) from .controlnet_xs import ( StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline, ) from .deepfloyd_if import ( IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, IFPipeline, IFSuperResolutionPipeline, ) from .deprecated import ( AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, CycleDiffusionPipeline, 
StableDiffusionInpaintPipelineLegacy, StableDiffusionModelEditingPipeline, StableDiffusionParadigmsPipeline, StableDiffusionPix2PixZeroPipeline, VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) from .flux import FluxPipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyImg2ImgPipeline, KandinskyInpaintCombinedPipeline, KandinskyInpaintPipeline, KandinskyPipeline, KandinskyPriorPipeline, ) from .kandinsky2_2 import ( KandinskyV22CombinedPipeline, KandinskyV22ControlnetImg2ImgPipeline, KandinskyV22ControlnetPipeline, KandinskyV22Img2ImgCombinedPipeline, KandinskyV22Img2ImgPipeline, KandinskyV22InpaintCombinedPipeline, KandinskyV22InpaintPipeline, KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline, KandinskyV22PriorPipeline, ) from .kandinsky3 import ( Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, ) from .latent_consistency_models import ( LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, ) from .latent_diffusion import LDMTextToImagePipeline from .latte import LattePipeline from .ledits_pp import ( LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) from .lumina import LuminaText2ImgPipeline from .marigold import ( MarigoldDepthPipeline, MarigoldNormalsPipeline, ) from .musicldm import MusicLDMPipeline from .pag import ( AnimateDiffPAGPipeline, HunyuanDiTPAGPipeline, KolorsPAGPipeline, PixArtSigmaPAGPipeline, StableDiffusion3PAGPipeline, StableDiffusionControlNetPAGPipeline, StableDiffusionPAGPipeline, StableDiffusionXLControlNetPAGPipeline, StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, ) from .paint_by_example import PaintByExamplePipeline from .pia 
import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_audio import StableAudioPipeline, StableAudioProjectionModel from .stable_cascade import ( StableCascadeCombinedPipeline, StableCascadeDecoderPipeline, StableCascadePriorPipeline, ) from .stable_diffusion import ( CLIPImageProjection, StableDiffusionDepth2ImgPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, StableDiffusionInstructPix2PixPipeline, StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline, StableDiffusionUpscalePipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, ) from .stable_diffusion_3 import ( StableDiffusion3Img2ImgPipeline, StableDiffusion3InpaintPipeline, StableDiffusion3Pipeline, ) from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline from .stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline from .stable_diffusion_panorama import StableDiffusionPanoramaPipeline from .stable_diffusion_safe import StableDiffusionPipelineSafe from .stable_diffusion_sag import StableDiffusionSAGPipeline from .stable_diffusion_xl import ( StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLInstructPix2PixPipeline, StableDiffusionXLPipeline, ) from .stable_video_diffusion import StableVideoDiffusionPipeline from .t2i_adapter import ( StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline, ) from .text_to_video_synthesis import ( TextToVideoSDPipeline, TextToVideoZeroPipeline, TextToVideoZeroSDXLPipeline, VideoToVideoSDPipeline, ) from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .unidiffuser 
import ( ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder, ) from .wuerstchen import ( WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_onnx_objects import * # noqa F403 else: from .onnx_utils import OnnxRuntimeModel try: if not (is_torch_available() and is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_onnx_objects import * else: from .stable_diffusion import ( OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionPipeline, OnnxStableDiffusionUpscalePipeline, StableDiffusionOnnxPipeline, ) try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * else: from .stable_diffusion_k_diffusion import ( StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline, ) try: if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_sentencepiece_objects import * else: from .kolors import ( KolorsImg2ImgPipeline, KolorsPipeline, ) try: if not is_flax_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_flax_objects import * # noqa F403 else: from .pipeline_flax_utils import FlaxDiffusionPipeline try: if not (is_flax_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_flax_and_transformers_objects import * else: 
# Lazy-import shim for the aMUSEd pipelines.
#
# When torch and transformers are both available, the real pipeline classes are
# registered in `_import_structure` and resolved on first attribute access via
# `_LazyModule`. Otherwise the dummy placeholders are exposed under the same names.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import (
        AmusedImg2ImgPipeline,
        AmusedInpaintPipeline,
        AmusedPipeline,
    )

    _dummy_objects.update(
        {
            "AmusedPipeline": AmusedPipeline,
            "AmusedImg2ImgPipeline": AmusedImg2ImgPipeline,
            "AmusedInpaintPipeline": AmusedInpaintPipeline,
        }
    )
else:
    _import_structure["pipeline_amused"] = ["AmusedPipeline"]
    _import_structure["pipeline_amused_img2img"] = ["AmusedImg2ImgPipeline"]
    _import_structure["pipeline_amused_inpaint"] = ["AmusedInpaintPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # Import every dummy (not only `AmusedPipeline`) so the eager path exposes
        # the same three names as the lazy path above.
        from ...utils.dummy_torch_and_transformers_objects import (
            AmusedImg2ImgPipeline,
            AmusedInpaintPipeline,
            AmusedPipeline,
        )
    else:
        from .pipeline_amused import AmusedPipeline
        from .pipeline_amused_img2img import AmusedImg2ImgPipeline
        from .pipeline_amused_inpaint import AmusedInpaintPipeline

else:
    import sys

    # Replace this module object with a lazy proxy, then attach the dummies (if any)
    # so they are reachable as plain attributes.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
class AmusedPipeline(DiffusionPipeline):
    """
    Text-to-image pipeline built from a VQ autoencoder, a CLIP text encoder with projection, a `UVit2DModel`
    transformer and an `AmusedScheduler`.

    Generation starts from a grid of latent token ids (all set to the scheduler's `mask_token_id` unless `latents`
    is supplied), repeatedly runs the transformer and `scheduler.step`, and finally decodes the tokens with
    `vqvae.decode`.
    """

    image_processor: VaeImageProcessor
    vqvae: VQModel
    tokenizer: CLIPTokenizer
    text_encoder: CLIPTextModelWithProjection
    transformer: UVit2DModel
    scheduler: AmusedScheduler

    model_cpu_offload_seq = "text_encoder->transformer->vqvae"

    def __init__(
        self,
        vqvae: VQModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        transformer: UVit2DModel,
        scheduler: AmusedScheduler,
    ):
        """Register the sub-modules and derive the pixel-to-latent scale factor from the VQ-VAE config."""
        super().__init__()

        self.register_modules(
            vqvae=vqvae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
        )
        # One factor of two per entry in `block_out_channels` beyond the first.
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[List[str], str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 12,
        guidance_scale: float = 10.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.IntTensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_encoder_hidden_states: Optional[torch.Tensor] = None,
        output_type="pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        micro_conditioning_aesthetic_score: int = 6,
        micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
        temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
    ):
        """
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 12):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 10.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.IntTensor`, *optional*):
                Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image
                generation. If not provided, the starting latents will be completely masked.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument. A single vector from the
                pooled and projected final hidden states.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            negative_encoder_hidden_states (`torch.Tensor`, *optional*):
                Analogous to `encoder_hidden_states` for the positive prompt.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`. Passing
                `"latent"` returns the latent tokens without decoding them.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.pipeline_utils.ImagePipelineOutput`] instead of a plain
                tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
                The targeted aesthetic score according to the laion aesthetic classifier. See
                https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
                https://arxiv.org/abs/2307.01952.
            micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
                The targeted height, width crop coordinates. See the micro-conditioning section of
                https://arxiv.org/abs/2307.01952.
            temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
                Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.

        Examples:

        Returns:
            [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
                `tuple` is returned where the first element is a list with the generated images.

        Raises:
            ValueError: If only one of `prompt_embeds` / `encoder_hidden_states` is given, if only one of
                `negative_prompt_embeds` / `negative_encoder_hidden_states` is given, or if `prompt` and
                `prompt_embeds` are both given or both missing.
        """
        if (prompt_embeds is not None and encoder_hidden_states is None) or (
            prompt_embeds is None and encoder_hidden_states is not None
        ):
            raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")

        if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
            negative_prompt_embeds is None and negative_encoder_hidden_states is not None
        ):
            raise ValueError(
                "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
            )

        if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
            raise ValueError("pass only one of `prompt` or `prompt_embeds`")

        if isinstance(prompt, str):
            prompt = [prompt]

        # Number of distinct prompts *before* expansion by `num_images_per_prompt`. Kept separately because
        # `prompt` is None when the caller supplies `prompt_embeds`, so `len(prompt)` cannot be used later.
        if prompt is not None:
            num_prompts = len(prompt)
        else:
            num_prompts = prompt_embeds.shape[0]

        batch_size = num_prompts * num_images_per_prompt

        if height is None:
            height = self.transformer.config.sample_size * self.vae_scale_factor

        if width is None:
            width = self.transformer.config.sample_size * self.vae_scale_factor

        use_guidance = guidance_scale > 1.0

        if prompt_embeds is None:
            input_ids = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids.to(self._execution_device)

            outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
            prompt_embeds = outputs.text_embeds
            encoder_hidden_states = outputs.hidden_states[-2]

        prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
        encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

        if use_guidance:
            if negative_prompt_embeds is None:
                if negative_prompt is None:
                    negative_prompt = [""] * num_prompts

                if isinstance(negative_prompt, str):
                    negative_prompt = [negative_prompt]

                input_ids = self.tokenizer(
                    negative_prompt,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.tokenizer.model_max_length,
                ).input_ids.to(self._execution_device)

                outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
                negative_prompt_embeds = outputs.text_embeds
                negative_encoder_hidden_states = outputs.hidden_states[-2]

            negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
            negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

            # Negative half first, positive half second; the loop below splits them with `chunk(2)` in that order.
            prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
            encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])

        # Note that the micro conditionings _do_ flip the order of width, height for the original size
        # and the crop coordinates. This is how it was done in the original code base
        micro_conds = torch.tensor(
            [
                width,
                height,
                micro_conditioning_crop_coord[0],
                micro_conditioning_crop_coord[1],
                micro_conditioning_aesthetic_score,
            ],
            device=self._execution_device,
            dtype=encoder_hidden_states.dtype,
        )
        micro_conds = micro_conds.unsqueeze(0)
        micro_conds = micro_conds.expand(2 * batch_size if use_guidance else batch_size, -1)

        shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)

        if latents is None:
            latents = torch.full(
                shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device
            )

        self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)

        num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, timestep in enumerate(self.scheduler.timesteps):
                # With guidance the same latents are fed twice to line up with the stacked embeddings.
                if use_guidance:
                    model_input = torch.cat([latents] * 2)
                else:
                    model_input = latents

                model_output = self.transformer(
                    model_input,
                    micro_conds=micro_conds,
                    pooled_text_emb=prompt_embeds,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                )

                if use_guidance:
                    uncond_logits, cond_logits = model_output.chunk(2)
                    model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)

                latents = self.scheduler.step(
                    model_output=model_output,
                    timestep=timestep,
                    sample=latents,
                    generator=generator,
                ).prev_sample

                if i == len(self.scheduler.timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, timestep, latents)

        if output_type == "latent":
            output = latents
        else:
            # Temporarily run the VQ-VAE in fp32 when its config asks for it, then put it back in fp16.
            needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast

            if needs_upcasting:
                self.vqvae.float()

            output = self.vqvae.decode(
                latents,
                force_not_quantize=True,
                shape=(
                    batch_size,
                    height // self.vae_scale_factor,
                    width // self.vae_scale_factor,
                    self.vqvae.config.latent_channels,
                ),
            ).sample.clip(0, 1)
            output = self.image_processor.postprocess(output, output_type)

            if needs_upcasting:
                self.vqvae.half()

        self.maybe_free_model_hooks()

        if not return_dict:
            return (output,)

        return ImagePipelineOutput(output)
class AmusedImg2ImgPipeline(DiffusionPipeline):
    """
    Image-to-image pipeline built from a VQ autoencoder, a CLIP text encoder with projection, a `UVit2DModel`
    transformer and an `AmusedScheduler`.

    The input image is encoded and quantized to latent token ids with `vqvae`, noised with `scheduler.add_noise`,
    refined over the trailing part of the timestep schedule selected by `strength`, and decoded with `vqvae.decode`.
    """

    image_processor: VaeImageProcessor
    vqvae: VQModel
    tokenizer: CLIPTokenizer
    text_encoder: CLIPTextModelWithProjection
    transformer: UVit2DModel
    scheduler: AmusedScheduler

    model_cpu_offload_seq = "text_encoder->transformer->vqvae"

    # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
    # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
    # off the meta device. There should be a way to fix this instead of just not offloading it
    _exclude_from_cpu_offload = ["vqvae"]

    def __init__(
        self,
        vqvae: VQModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        transformer: UVit2DModel,
        scheduler: AmusedScheduler,
    ):
        """Register the sub-modules and derive the pixel-to-latent scale factor from the VQ-VAE config."""
        super().__init__()

        self.register_modules(
            vqvae=vqvae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
        )
        # One factor of two per entry in `block_out_channels` beyond the first.
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[List[str], str]] = None,
        image: PipelineImageInput = None,
        strength: float = 0.5,
        num_inference_steps: int = 12,
        guidance_scale: float = 10.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[torch.Generator] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_encoder_hidden_states: Optional[torch.Tensor] = None,
        output_type="pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        micro_conditioning_aesthetic_score: int = 6,
        micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
        temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
    ):
        """
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
                or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
                latents as `image`, but if passing latents directly it is not encoded again.
            strength (`float`, *optional*, defaults to 0.5):
                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
                essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 12):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 10.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument. A single vector from the
                pooled and projected final hidden states.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            negative_encoder_hidden_states (`torch.Tensor`, *optional*):
                Analogous to `encoder_hidden_states` for the positive prompt.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`. Passing
                `"latent"` returns the latent tokens without decoding them.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.pipeline_utils.ImagePipelineOutput`] instead of a plain
                tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
                The targeted aesthetic score according to the laion aesthetic classifier. See
                https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
                https://arxiv.org/abs/2307.01952.
            micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
                The targeted height, width crop coordinates. See the micro-conditioning section of
                https://arxiv.org/abs/2307.01952.
            temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
                Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.

        Examples:

        Returns:
            [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
                `tuple` is returned where the first element is a list with the generated images.

        Raises:
            ValueError: If only one of `prompt_embeds` / `encoder_hidden_states` is given, if only one of
                `negative_prompt_embeds` / `negative_encoder_hidden_states` is given, or if `prompt` and
                `prompt_embeds` are both given or both missing.
        """
        if (prompt_embeds is not None and encoder_hidden_states is None) or (
            prompt_embeds is None and encoder_hidden_states is not None
        ):
            raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")

        if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
            negative_prompt_embeds is None and negative_encoder_hidden_states is not None
        ):
            raise ValueError(
                "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
            )

        if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
            raise ValueError("pass only one of `prompt` or `prompt_embeds`")

        if isinstance(prompt, str):
            prompt = [prompt]

        # Number of distinct prompts *before* expansion by `num_images_per_prompt`. Kept separately because
        # `prompt` is None when the caller supplies `prompt_embeds`, so `len(prompt)` cannot be used later.
        if prompt is not None:
            num_prompts = len(prompt)
        else:
            num_prompts = prompt_embeds.shape[0]

        batch_size = num_prompts * num_images_per_prompt

        use_guidance = guidance_scale > 1.0

        if prompt_embeds is None:
            input_ids = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids.to(self._execution_device)

            outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
            prompt_embeds = outputs.text_embeds
            encoder_hidden_states = outputs.hidden_states[-2]

        prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
        encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

        if use_guidance:
            if negative_prompt_embeds is None:
                if negative_prompt is None:
                    negative_prompt = [""] * num_prompts

                if isinstance(negative_prompt, str):
                    negative_prompt = [negative_prompt]

                input_ids = self.tokenizer(
                    negative_prompt,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.tokenizer.model_max_length,
                ).input_ids.to(self._execution_device)

                outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
                negative_prompt_embeds = outputs.text_embeds
                negative_encoder_hidden_states = outputs.hidden_states[-2]

            negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
            negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

            # Negative half first, positive half second; the loop below splits them with `chunk(2)` in that order.
            prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
            encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])

        image = self.image_processor.preprocess(image)

        height, width = image.shape[-2:]

        # Note that the micro conditionings _do_ flip the order of width, height for the original size
        # and the crop coordinates. This is how it was done in the original code base
        micro_conds = torch.tensor(
            [
                width,
                height,
                micro_conditioning_crop_coord[0],
                micro_conditioning_crop_coord[1],
                micro_conditioning_aesthetic_score,
            ],
            device=self._execution_device,
            dtype=encoder_hidden_states.dtype,
        )

        micro_conds = micro_conds.unsqueeze(0)
        micro_conds = micro_conds.expand(2 * batch_size if use_guidance else batch_size, -1)

        self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
        # Only the last `strength` fraction of the schedule is actually run.
        num_inference_steps = int(len(self.scheduler.timesteps) * strength)
        start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps

        # Temporarily run the VQ-VAE in fp32 when its config asks for it; undone after the output is produced.
        needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast

        if needs_upcasting:
            self.vqvae.float()

        latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents
        latents_bsz, _, latents_height, latents_width = latents.shape
        latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width)

        # NOTE(review): when `strength == 1`, `start_timestep_idx` is 0 and this indexes `timesteps[-1]`
        # (the last timestep) -- confirm against `AmusedScheduler.add_noise` that this is the intended noise level.
        latents = self.scheduler.add_noise(
            latents, self.scheduler.timesteps[start_timestep_idx - 1], generator=generator
        )
        latents = latents.repeat(num_images_per_prompt, 1, 1)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i in range(start_timestep_idx, len(self.scheduler.timesteps)):
                timestep = self.scheduler.timesteps[i]

                # With guidance the same latents are fed twice to line up with the stacked embeddings.
                if use_guidance:
                    model_input = torch.cat([latents] * 2)
                else:
                    model_input = latents

                model_output = self.transformer(
                    model_input,
                    micro_conds=micro_conds,
                    pooled_text_emb=prompt_embeds,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                )

                if use_guidance:
                    uncond_logits, cond_logits = model_output.chunk(2)
                    model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)

                latents = self.scheduler.step(
                    model_output=model_output,
                    timestep=timestep,
                    sample=latents,
                    generator=generator,
                ).prev_sample

                if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, timestep, latents)

        if output_type == "latent":
            output = latents
        else:
            output = self.vqvae.decode(
                latents,
                force_not_quantize=True,
                shape=(
                    batch_size,
                    height // self.vae_scale_factor,
                    width // self.vae_scale_factor,
                    self.vqvae.config.latent_channels,
                ),
            ).sample.clip(0, 1)
            output = self.image_processor.postprocess(output, output_type)

        # Restore fp16 on every path: the upcast above happens before encoding, so it must also be
        # undone when `output_type == "latent"` skips decoding.
        if needs_upcasting:
            self.vqvae.half()

        self.maybe_free_model_hooks()

        if not return_dict:
            return (output,)

        return ImagePipelineOutput(output)
class AmusedInpaintPipeline(DiffusionPipeline):
    """
    Text-guided inpainting pipeline built from a VQ autoencoder, a CLIP text encoder and a token transformer.

    The input image is encoded to a grid of VQ token ids, the masked positions are overwritten with the scheduler's
    mask token id, and the transformer/scheduler loop re-predicts those tokens before the grid is decoded back to an
    image.
    """

    image_processor: VaeImageProcessor
    vqvae: VQModel
    tokenizer: CLIPTokenizer
    text_encoder: CLIPTextModelWithProjection
    transformer: UVit2DModel
    scheduler: AmusedScheduler

    model_cpu_offload_seq = "text_encoder->transformer->vqvae"

    # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
    # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
    # off the meta device. There should be a way to fix this instead of just not offloading it
    _exclude_from_cpu_offload = ["vqvae"]

    def __init__(
        self,
        vqvae: VQModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        transformer: UVit2DModel,
        scheduler: AmusedScheduler,
    ):
        """Register the sub-models and build the image and mask pre-processors."""
        super().__init__()

        self.register_modules(
            vqvae=vqvae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
        )
        # One factor of two per VQ-VAE resolution level after the first.
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
        # Masks are resized, converted to a single channel and binarized.
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
            do_resize=True,
        )
        self.scheduler.register_to_config(masking_schedule="linear")

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[List[str], str]] = None,
        image: PipelineImageInput = None,
        mask_image: PipelineImageInput = None,
        strength: float = 1.0,
        num_inference_steps: int = 12,
        guidance_scale: float = 10.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[torch.Generator] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_encoder_hidden_states: Optional[torch.Tensor] = None,
        output_type="pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        micro_conditioning_aesthetic_score: int = 6,
        micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
        temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
    ):
        """
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
                or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
                latents as `image`, but if passing latents directly it is not encoded again.
            mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
                are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
                single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
                color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
                H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
                1)`, or `(H, W)`.
            strength (`float`, *optional*, defaults to 1.0):
                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
                essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 12):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 10.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. A single string is applied to every prompt. Ignored when not
                using guidance (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument. A single vector from the
                pooled and projected final hidden states.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            negative_encoder_hidden_states (`torch.Tensor`, *optional*):
                Analogous to `encoder_hidden_states` for the positive prompt.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`. Passing
                `"latent"` returns the token grid without decoding it.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
                The targeted aesthetic score according to the laion aesthetic classifier. See
                https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
                https://arxiv.org/abs/2307.01952.
            micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
                The targeted height, width crop coordinates. See the micro-conditioning section of
                https://arxiv.org/abs/2307.01952.
            temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
                Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.

        Examples:

        Returns:
            [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
                `tuple` is returned where the first element is a list with the generated images.

        Raises:
            ValueError: If only one of `prompt_embeds`/`encoder_hidden_states` (or of their negative counterparts) is
                given, or if `prompt` and `prompt_embeds` are both given or both omitted.
        """
        # The pooled embedding and the hidden states always travel as a pair.
        if (prompt_embeds is None) != (encoder_hidden_states is None):
            raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")

        if (negative_prompt_embeds is None) != (negative_encoder_hidden_states is None):
            raise ValueError(
                "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
            )

        if (prompt is None) == (prompt_embeds is None):
            raise ValueError("pass only one of `prompt` or `prompt_embeds`")

        if isinstance(prompt, str):
            prompt = [prompt]

        # Number of distinct prompts before expansion by `num_images_per_prompt`. Kept separately so that the
        # default negative prompts can be sized even when only `prompt_embeds` was supplied.
        num_prompts = len(prompt) if prompt is not None else prompt_embeds.shape[0]
        batch_size = num_prompts * num_images_per_prompt

        if prompt_embeds is None:
            input_ids = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids.to(self._execution_device)

            outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
            prompt_embeds = outputs.text_embeds
            encoder_hidden_states = outputs.hidden_states[-2]

        prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
        encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

        use_guidance = guidance_scale > 1.0

        if use_guidance:
            if negative_prompt_embeds is None:
                if negative_prompt is None:
                    negative_prompt = [""] * num_prompts
                elif isinstance(negative_prompt, str):
                    # A single negative prompt is shared by every prompt so that the unconditional batch has the
                    # same size as the conditional one.
                    negative_prompt = [negative_prompt] * num_prompts

                input_ids = self.tokenizer(
                    negative_prompt,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.tokenizer.model_max_length,
                ).input_ids.to(self._execution_device)

                outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
                negative_prompt_embeds = outputs.text_embeds
                negative_encoder_hidden_states = outputs.hidden_states[-2]

            negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
            negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

            # Unconditional half first, conditional half second; the logits are chunked in that order below.
            prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
            encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])

        image = self.image_processor.preprocess(image)

        height, width = image.shape[-2:]

        # Note that the micro conditionings _do_ flip the order of width, height for the original size
        # and the crop coordinates. This is how it was done in the original code base
        micro_conds = torch.tensor(
            [
                width,
                height,
                micro_conditioning_crop_coord[0],
                micro_conditioning_crop_coord[1],
                micro_conditioning_aesthetic_score,
            ],
            device=self._execution_device,
            dtype=encoder_hidden_states.dtype,
        )

        micro_conds = micro_conds.unsqueeze(0)
        micro_conds = micro_conds.expand(2 * batch_size if use_guidance else batch_size, -1)

        self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
        # `strength` selects how many of the trailing scheduler steps are actually run.
        num_inference_steps = int(len(self.scheduler.timesteps) * strength)
        start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps

        needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast

        if needs_upcasting:
            self.vqvae.float()

        latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents
        latents_bsz, channels, latents_height, latents_width = latents.shape
        latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width)

        mask = self.mask_processor.preprocess(
            mask_image, height // self.vae_scale_factor, width // self.vae_scale_factor
        )
        mask = mask.reshape(mask.shape[0], latents_height, latents_width).bool().to(latents.device)
        # Every masked position becomes the scheduler's mask token and is re-predicted in the loop.
        latents[mask] = self.scheduler.config.mask_token_id

        starting_mask_ratio = mask.sum() / latents.numel()

        latents = latents.repeat(num_images_per_prompt, 1, 1)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i in range(start_timestep_idx, len(self.scheduler.timesteps)):
                timestep = self.scheduler.timesteps[i]

                if use_guidance:
                    model_input = torch.cat([latents] * 2)
                else:
                    model_input = latents

                model_output = self.transformer(
                    model_input,
                    micro_conds=micro_conds,
                    pooled_text_emb=prompt_embeds,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                )

                if use_guidance:
                    uncond_logits, cond_logits = model_output.chunk(2)
                    model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)

                latents = self.scheduler.step(
                    model_output=model_output,
                    timestep=timestep,
                    sample=latents,
                    generator=generator,
                    starting_mask_ratio=starting_mask_ratio,
                ).prev_sample

                if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, timestep, latents)

        if output_type == "latent":
            output = latents
        else:
            output = self.vqvae.decode(
                latents,
                force_not_quantize=True,
                shape=(
                    batch_size,
                    height // self.vae_scale_factor,
                    width // self.vae_scale_factor,
                    self.vqvae.config.latent_channels,
                ),
            ).sample.clip(0, 1)
            output = self.image_processor.postprocess(output, output_type)

        # Undo the temporary float32 upcast on every exit path, including `output_type="latent"`, so the
        # pipeline does not hand back a VQ-VAE in a different precision than it was given.
        if needs_upcasting:
            self.vqvae.half()

        self.maybe_free_model_hooks()

        if not return_dict:
            return (output,)

        return ImagePipelineOutput(output)
OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"] _import_structure["pipeline_animatediff_controlnet"] = ["AnimateDiffControlNetPipeline"] _import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"] _import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"] _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_animatediff import AnimateDiffPipeline from .pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline from .pipeline_output import AnimateDiffPipelineOutput else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/animatediff/pipeline_animatediff.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unets.unet_motion_model import MotionAdapter from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, ) from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler >>> from diffusers.utils import export_to_gif >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") >>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", 
motion_adapter=adapter) >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False) >>> output = pipe(prompt="A corgi walking in the park") >>> frames = output.frames[0] >>> export_to_gif(frames, "animation.gif") ``` """ class AnimateDiffPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, FreeInitMixin, AnimateDiffFreeNoiseMixin, ): r""" Pipeline for text-to-video generation. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer (`CLIPTokenizer`): A [`~transformers.CLIPTokenizer`] to tokenize text. unet ([`UNet2DConditionModel`]): A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. motion_adapter ([`MotionAdapter`]): A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
""" model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: Union[UNet2DConditionModel, UNetMotionModel], motion_adapter: MotionAdapter, scheduler: Union[ DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ], feature_extractor: CLIPImageProcessor = None, image_encoder: CLIPVisionModelWithProjection = None, ): super().__init__() if isinstance(unet, UNet2DConditionModel): unet = UNetMotionModel.from_unet2d(unet, motion_adapter) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, motion_adapter=motion_adapter, scheduler=scheduler, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. 
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Turn text prompts (or ready-made embeddings) into the conditioning tensors used for denoising.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt(s) to encode. Only tokenized when `prompt_embeds` is not given.
        device: (`torch.device`):
            Device the returned embeddings are moved to.
        num_images_per_prompt (`int`):
            How many times each prompt's embedding is duplicated.
        do_classifier_free_guidance (`bool`):
            When true, unconditional embeddings are produced as well.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt(s) for the unconditional branch. Defaults to empty strings. Only read when guidance is on and
            `negative_prompt_embeds` is not given.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed text embeddings; skips tokenization and text encoding of `prompt`.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed unconditional embeddings; skips encoding of `negative_prompt`.
        lora_scale (`float`, *optional*):
            Scale applied to the text encoder's LoRA layers for the duration of this call.
        clip_skip (`int`, *optional*):
            If set, take the hidden states `clip_skip` layers before the last one and pass them through the text
            model's final layer norm, instead of using the encoder's first output.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`. The second item is whatever was passed in (possibly
        `None`) when guidance is off.

    Raises:
        TypeError: If `negative_prompt` and `prompt` are of different types.
        ValueError: If a list `negative_prompt` does not have one entry per prompt.
    """

    def _attention_mask_for(tokenized):
        # The mask is only forwarded when the text encoder's config asks for it.
        if getattr(self.text_encoder.config, "use_attention_mask", False):
            return tokenized.attention_mask.to(device)
        return None

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder, lora_scale)
        else:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)

    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Warn when the un-truncated tokenization differs, i.e. part of the prompt was cut off.
        prompt_was_cut = untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        )
        if prompt_was_cut:
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        attention_mask = _attention_mask_for(text_inputs)
        ids_on_device = text_input_ids.to(device)

        if clip_skip is None:
            prompt_embeds = self.text_encoder(ids_on_device, attention_mask=attention_mask)[0]
        else:
            encoder_output = self.text_encoder(
                ids_on_device, attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            skipped_hidden = encoder_output[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(skipped_hidden)

    # Prefer the text encoder's dtype, then the UNet's, and only then the embeddings' own dtype.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
        bs_embed * num_images_per_prompt, seq_len, -1
    )

    if do_classifier_free_guidance:
        # get unconditional embeddings for classifier free guidance
        if negative_prompt_embeds is None:
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad the unconditional tokens to the same sequence length as the conditional embeddings.
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=prompt_embeds.shape[1],
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=_attention_mask_for(uncond_input),
            )[0]

        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        uncond_seq_len = negative_prompt_embeds.shape[1]
        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
            batch_size * num_images_per_prompt, uncond_seq_len, -1
        )

    if (
        self.text_encoder is not None
        and isinstance(self, StableDiffusionLoraLoaderMixin)
        and USE_PEFT_BACKEND
    ):
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build one embedding tensor per IP Adapter, expanded for `num_images_per_prompt`.

    When `ip_adapter_image_embeds` is `None` the images are encoded with `self.encode_image` (one image per
    projection layer of `self.unet.encoder_hid_proj`). Otherwise the given tensors are used; with guidance on,
    each is split in two along the first dimension into a negative and a positive half.

    Returns:
        `list` of tensors on `device`. With guidance on, each tensor holds the repeated negative embeddings
        followed by the repeated positive ones along the first dimension.

    Raises:
        ValueError: If the number of images does not match the number of IP Adapter projection layers.
    """
    positives = []
    negatives = []

    if ip_adapter_image_embeds is not None:
        for packed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                # Pre-computed embeddings carry the negative half first.
                negative_half, packed = packed.chunk(2)
                negatives.append(negative_half)
            positives.append(packed)
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        projection_layers = self.unet.encoder_hid_proj.image_projection_layers
        if len(ip_adapter_image) != len(projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        for adapter_image, projection in zip(ip_adapter_image, projection_layers):
            # Plain `ImageProjection` layers take pooled embeds; every other layer type takes hidden states.
            wants_hidden_states = not isinstance(projection, ImageProjection)
            encoded, encoded_negative = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positives.append(encoded[None, :])
            if do_classifier_free_guidance:
                negatives.append(encoded_negative[None, :])

    prepared = []
    for index, positive in enumerate(positives):
        expanded = torch.cat([positive] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            expanded_negative = torch.cat([negatives[index]] * num_images_per_prompt, dim=0)
            expanded = torch.cat([expanded_negative, expanded], dim=0)
        prepared.append(expanded.to(device=device))

    return prepared
video = video.float() return video # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ip_adapter_image=None, ip_adapter_image_embeds=None, callback_on_step_end_tensor_inputs=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
) if ip_adapter_image_embeds is not None: if not isinstance(ip_adapter_image_embeds, list): raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169) if self.free_noise_enabled: latents = self._prepare_latents_free_noise( batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) shape = ( batch_size, num_channels_latents, num_frames, height // self.vae_scale_factor, width // self.vae_scale_factor, ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @property def guidance_scale(self): return self._guidance_scale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
@property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @property def cross_attention_kwargs(self): return self._cross_attention_kwargs @property def num_timesteps(self): return self._num_timesteps @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, num_frames: Optional[int] = 16, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_videos_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], decode_chunk_size: int = 16, **kwargs, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated video. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated video. num_frames (`int`, *optional*, defaults to 16): The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds amounts to 2 seconds of video. 
num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality videos at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. Latents should be of shape `(batch_size, num_channel, num_frames, height, width)`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. decode_chunk_size (`int`, defaults to `16`): The number of frames to decode at a time when calling `decode_latents` method. Examples: Returns: [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. """ callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. 
Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_videos_per_prompt, self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt, self.do_classifier_free_guidance, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, num_frames, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None ) num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 for free_init_iter in range(num_free_init_iters): if self.free_init_enabled: latents, timesteps = self._apply_free_init( latents, free_init_iter, num_inference_steps, device, latents.dtype, generator ) self._num_timesteps = len(timesteps) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order # 8. Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, ).sample # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % 
self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 9. Post processing if output_type == "latent": video = latents else: video_tensor = self.decode_latents(latents, decode_chunk_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() if not return_dict: return (video,) return AnimateDiffPipelineOutput(frames=video) ================================================ FILE: src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unets.unet_motion_model import MotionAdapter from ...schedulers import KarrasDiffusionSchedulers from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import is_compiled_module, randn_tensor from ...video_processor import VideoProcessor from ..controlnet.multicontrolnet import MultiControlNetModel from ..free_init_utils import FreeInitMixin from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import ( ... AnimateDiffControlNetPipeline, ... AutoencoderKL, ... ControlNetModel, ... MotionAdapter, ... LCMScheduler, ... 
) >>> from diffusers.utils import export_to_gif, load_video >>> # Additionally, you will need a preprocess videos before they can be used with the ControlNet >>> # HF maintains just the right package for it: `pip install controlnet_aux` >>> from controlnet_aux.processor import ZoeDetector >>> # Download controlnets from https://huggingface.co/lllyasviel/ControlNet-v1-1 to use .from_single_file >>> # Download Diffusers-format controlnets, such as https://huggingface.co/lllyasviel/sd-controlnet-depth, to use .from_pretrained() >>> controlnet = ControlNetModel.from_single_file("control_v11f1p_sd15_depth.pth", torch_dtype=torch.float16) >>> # We use AnimateLCM for this example but one can use the original motion adapters as well (for example, https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3) >>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM") >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) >>> pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained( ... "SG161222/Realistic_Vision_V5.1_noVAE", ... motion_adapter=motion_adapter, ... controlnet=controlnet, ... vae=vae, ... ).to(device="cuda", dtype=torch.float16) >>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear") >>> pipe.load_lora_weights( ... "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora" ... ) >>> pipe.set_adapters(["lcm-lora"], [0.8]) >>> depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda") >>> video = load_video( ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif" ... ) >>> conditioning_frames = [] >>> with pipe.progress_bar(total=len(video)) as progress_bar: ... for frame in video: ... conditioning_frames.append(depth_detector(frame)) ... 
progress_bar.update() >>> prompt = "a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality" >>> negative_prompt = "bad quality, worst quality" >>> video = pipe( ... prompt=prompt, ... negative_prompt=negative_prompt, ... num_frames=len(video), ... num_inference_steps=10, ... guidance_scale=2.0, ... conditioning_frames=conditioning_frames, ... generator=torch.Generator().manual_seed(42), ... ).frames[0] >>> export_to_gif(video, "animatediff_controlnet.gif", fps=8) ``` """ class AnimateDiffControlNetPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, FreeInitMixin, AnimateDiffFreeNoiseMixin, ): r""" Pipeline for text-to-video generation with ControlNet guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer (`CLIPTokenizer`): A [`~transformers.CLIPTokenizer`] to tokenize text. unet ([`UNet2DConditionModel`]): A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. 
motion_adapter ([`MotionAdapter`]): A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. """ model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["feature_extractor", "image_encoder"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: Union[UNet2DConditionModel, UNetMotionModel], motion_adapter: MotionAdapter, controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, feature_extractor: Optional[CLIPImageProcessor] = None, image_encoder: Optional[CLIPVisionModelWithProjection] = None, ): super().__init__() if isinstance(unet, UNet2DConditionModel): unet = UNetMotionModel.from_unet2d(unet, motion_adapter) if isinstance(controlnet, (list, tuple)): controlnet = MultiControlNetModel(controlnet) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, motion_adapter=motion_adapter, controlnet=controlnet, scheduler=scheduler, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) self.control_video_processor = VideoProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: 
Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the 
encoder layers. Then index into # the tuple to access the hidden states from the desired layer. prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if self.text_encoder is not None: if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) if output_hidden_states: image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] 
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Build the per-adapter image embeddings that are handed to the UNet.

    Args:
        ip_adapter_image: A single image or a list with one image per IP Adapter. Only used when
            `ip_adapter_image_embeds` is `None`.
        ip_adapter_image_embeds: Optional list of pre-computed embeddings, one tensor per adapter. Under
            classifier-free guidance each tensor is split in two along dim 0 (negative half first).
        device: Device the returned tensors are moved to.
        num_images_per_prompt: How many times each embedding is tiled along dim 0.
        do_classifier_free_guidance: When true, the tiled negative embedding is concatenated in front of
            the tiled positive embedding.

    Returns:
        A list with one tensor per adapter.

    Raises:
        ValueError: If images are given and their count differs from the number of image projection layers.
    """
    positive_embeds = []
    # The negative list is only ever touched under classifier-free guidance.
    negative_embeds = [] if do_classifier_free_guidance else None

    if ip_adapter_image_embeds is not None:
        # Pre-computed path: each entry already packs [negative, positive] when guidance is on.
        for packed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_half, packed = packed.chunk(2)
                negative_embeds.append(negative_half)
            positive_embeds.append(packed)
    else:
        # Encoding path: one image per projection layer.
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]
        projection_layers = self.unet.encoder_hid_proj.image_projection_layers

        if len(ip_adapter_image) != len(projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(projection_layers)} IP Adapters."
            )

        for adapter_image, projection_layer in zip(ip_adapter_image, projection_layers):
            # Anything that is not a plain ImageProjection is fed encoder hidden states instead.
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            encoded, encoded_negative = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positive_embeds.append(encoded[None, :])
            if do_classifier_free_guidance:
                negative_embeds.append(encoded_negative[None, :])

    prepared = []
    for index, embeds in enumerate(positive_embeds):
        tiled = torch.cat([embeds] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negative_embeds[index]] * num_images_per_prompt, dim=0)
            tiled = torch.cat([tiled_negative, tiled], dim=0)
        prepared.append(tiled.to(device=device))

    return prepared
def check_inputs(
    self,
    prompt,
    height,
    width,
    num_frames,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    video=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
):
    """Validate the arguments of `__call__` and raise on the first inconsistency found.

    Args:
        prompt: `str`, `list` or `None`. Exactly one of `prompt` / `prompt_embeds` must be given.
        height: Output height; must be divisible by 8.
        width: Output width; must be divisible by 8.
        num_frames: Number of frames each conditioning video must contain.
        negative_prompt: Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: Pre-computed prompt embeddings.
        negative_prompt_embeds: Pre-computed negative embeddings; must match the shape of `prompt_embeds`
            when both are passed.
        callback_on_step_end_tensor_inputs: Names that must all appear in `self._callback_tensor_inputs`.
        video: Conditioning frames. A list of `num_frames` items for a single ControlNet, a list of such
            lists for a `MultiControlNetModel`.
        controlnet_conditioning_scale: A `float` for a single ControlNet. For a `MultiControlNetModel`, a
            flat list is checked to have one entry per ControlNet.
        control_guidance_start: Scalar or list/tuple of start fractions (each `>= 0` and `< end`).
        control_guidance_end: Scalar or list/tuple of end fractions (each `<= 1`).

    Raises:
        ValueError: On inconsistent values (sizes, lengths, mutually exclusive arguments, ranges).
        TypeError: On conditioning inputs of the wrong type.
        AssertionError: If `self.controlnet` is neither a `ControlNetModel` nor a `MultiControlNetModel`.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # A torch.compile'd controlnet is a wrapper; look at the wrapped module to classify it.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    is_single_controlnet = isinstance(self.controlnet, ControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel)
    )
    is_multi_controlnet = isinstance(self.controlnet, MultiControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    )

    if is_single_controlnet:
        # Check `image`
        if not isinstance(video, list):
            raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(video)}")
        if len(video) != num_frames:
            raise ValueError(f"Expected image to have length {num_frames} but got {len(video)=}")

        # Check `controlnet_conditioning_scale`
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif is_multi_controlnet:
        # Check `image`. The emptiness test keeps `video[0]` from raising an unrelated IndexError.
        if not isinstance(video, list) or len(video) == 0 or not isinstance(video[0], list):
            raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(video)=}")
        if len(video[0]) != num_frames:
            raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(video[0])=}")
        if any(len(img) != len(video[0]) for img in video):
            raise ValueError("All conditioning frame batches for multicontrolnet must be same size")

        # Check `controlnet_conditioning_scale`. Both list checks must live under the same
        # `isinstance(..., list)` guard; as a sibling `elif` the length check could never run.
        if isinstance(controlnet_conditioning_scale, list):
            if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
            if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
                raise ValueError(
                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of controlnets"
                )
    else:
        # Explicit raise instead of `assert False`, which is stripped when Python runs with -O.
        raise AssertionError(f"Unsupported controlnet type: {type(self.controlnet)}")

    if not isinstance(control_guidance_start, (tuple, list)):
        control_guidance_start = [control_guidance_start]

    if not isinstance(control_guidance_end, (tuple, list)):
        control_guidance_end = [control_guidance_end]

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    num_frames: Optional[int] = 16,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[PipelineImageInput] = None,
    conditioning_frames: Optional[List[PipelineImageInput]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    guess_mode: bool = False,
    control_guidance_start: Union[float, List[float]] = 0.0,
    control_guidance_end: Union[float, List[float]] = 1.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    # NOTE(review): mutable default; it is only iterated in this function, never modified.
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    decode_chunk_size: int = 16,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated video.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated video.
        num_frames (`int`, *optional*, defaults to 16):
            The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
            amounts to 2 seconds of video.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Currently has no effect: the value is overwritten with `1` inside this function.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only
            applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor is generated by sampling using the supplied random `generator`. Latents should be of
            shape `(batch_size, num_channel, num_frames, height, width)`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting).
            If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        conditioning_frames (`List[PipelineImageInput]`, *optional*):
            The ControlNet input condition to provide guidance to the `unet` for generation. If multiple
            ControlNets are specified, images must be passed as a list such that each element of the list can be
            correctly batched for input to a single ControlNet.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
            When set to `"latent"`, the denoised latents are returned without being decoded.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return an [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`]
            instead of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
            the corresponding scale as a list.
        guess_mode (`bool`, *optional*, defaults to `False`):
            The ControlNet encoder tries to recognize the content of the input image even if you remove all
            prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
        control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
            The percentage of total steps at which the ControlNet starts applying.
        control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
            The percentage of total steps at which the ControlNet stops applying.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        decode_chunk_size (`int`, *optional*, defaults to 16):
            Forwarded to `decode_latents`, which decodes the frame latents through the VAE in slices of this
            many frames.

    Examples:

    Returns:
        [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
            returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
    """

    # Work on the wrapped module when the controlnet has been compiled.
    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

    # align format for control guidance: after this both values are lists of equal length
    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        # Two scalars: replicate them once per controlnet.
        mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
        control_guidance_start, control_guidance_end = (
            mult * [control_guidance_start],
            mult * [control_guidance_end],
        )

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # The caller-supplied value is discarded here; everything below runs with one video per prompt.
    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt=prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        negative_prompt=negative_prompt,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        video=conditioning_frames,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_start=control_guidance_start,
        control_guidance_end=control_guidance_end,
    )

    # Stored on the instance so the `guidance_scale`, `clip_skip`, `cross_attention_kwargs` and
    # `do_classifier_free_guidance` properties reflect this call.
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # A scalar scale is broadcast to one entry per controlnet.
    if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
        controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

    # For a multi-controlnet only the first net's config is consulted.
    global_pool_conditions = (
        controlnet.config.global_pool_conditions
        if isinstance(controlnet, ControlNetModel)
        else controlnet.nets[0].config.global_pool_conditions
    )
    guess_mode = guess_mode or global_pool_conditions

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_videos_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # `image_embeds` is only bound when IP-Adapter inputs were given; step 7 uses the same condition.
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )

    # Preprocess the conditioning frames: one tensor for a single controlnet, a list of tensors otherwise.
    if isinstance(controlnet, ControlNetModel):
        conditioning_frames = self.prepare_video(
            video=conditioning_frames,
            width=width,
            height=height,
            batch_size=batch_size * num_videos_per_prompt * num_frames,
            num_videos_per_prompt=num_videos_per_prompt,
            device=device,
            dtype=controlnet.dtype,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            guess_mode=guess_mode,
        )
    elif isinstance(controlnet, MultiControlNetModel):
        cond_prepared_videos = []
        for frame_ in conditioning_frames:
            prepared_video = self.prepare_video(
                video=frame_,
                width=width,
                height=height,
                batch_size=batch_size * num_videos_per_prompt * num_frames,
                num_videos_per_prompt=num_videos_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            cond_prepared_videos.append(prepared_video)
        conditioning_frames = cond_prepared_videos
    else:
        assert False

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        num_frames,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    # 7.1 Create tensor stating which controlnets to keep
    # Each entry is 1.0 while step i lies inside [start, end] of the schedule and 0.0 otherwise.
    # NOTE(review): built from `timesteps` before `_apply_free_init` may replace it below —
    # confirm the two always have the same length.
    controlnet_keep = []
    for i in range(len(timesteps)):
        keeps = [
            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

    # Without FreeInit the denoising loop runs exactly once.
    num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
    for free_init_iter in range(num_free_init_iters):
        if self.free_init_enabled:
            latents, timesteps = self._apply_free_init(
                latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
            )

        self._num_timesteps = len(timesteps)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # 8. Denoising loop
        with self.progress_bar(total=self._num_timesteps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if guess_mode and self.do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds
                # The controlnet sees every frame as its own sample, so repeat the text embeddings per frame.
                controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)

                # Combine the user-provided scale with the 0/1 schedule mask for this step.
                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                # Swap dims 1 and 2 and merge the leading two dims, turning the 5-D latents into a
                # 4-D batch for the controlnet.
                control_model_input = torch.transpose(control_model_input, 1, 2)
                control_model_input = control_model_input.reshape(
                    (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
                )

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=conditioning_frames,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if callback_on_step_end is not None:
                    # Requested tensors are looked up by name among this function's local variables,
                    # so the local names used here are part of the callback contract.
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # The callback may hand back replacements for these three tensors.
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # advance the progress bar on the last step, or after warmup on every `scheduler.order`-th step
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

    # 9. Post processing
    if output_type == "latent":
        video = latents
    else:
        video_tensor = self.decode_latents(latents, decode_chunk_size)
        video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

    # 10. Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return AnimateDiffPipelineOutput(frames=video)
import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ...image_processor import PipelineImageInput from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, ImageProjection, MotionAdapter, UNet2DConditionModel, UNetMotionModel from ...models.attention_processor import ( AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, ) from ...utils import ( USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers.models import MotionAdapter >>> from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler >>> from diffusers.utils import export_to_gif >>> adapter = MotionAdapter.from_pretrained( ... "a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16 ... ) >>> model_id = "stabilityai/stable-diffusion-xl-base-1.0" >>> scheduler = DDIMScheduler.from_pretrained( ... model_id, ... subfolder="scheduler", ... clip_sample=False, ... timestep_spacing="linspace", ... beta_schedule="linear", ... steps_offset=1, ... ) >>> pipe = AnimateDiffSDXLPipeline.from_pretrained( ... model_id, ... 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or (when neither is
    given) a plain `num_inference_steps` count. Extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read the timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps. Only used when both `timesteps` and `sigmas` are `None`.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as the `device` keyword.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after the call, and the number of inference steps (the
        length of that schedule when a custom schedule was supplied, otherwise `num_inference_steps` as given).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` signature
            lacks the parameter needed for the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own spacing from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom path: the scheduler must explicitly accept the corresponding keyword.
    supported_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1-0`. 
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt
def encode_prompt(
    self,
    prompt: str,
    prompt_2: Optional[str] = None,
    device: Optional[torch.device] = None,
    num_videos_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[str] = None,
    negative_prompt_2: Optional[str] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt (and, for classifier-free guidance, the negative prompt) into text encoder hidden states
    and pooled embeddings, duplicated `num_videos_per_prompt` times.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        device: (`torch.device`):
            torch device. Falls back to `self._execution_device` when not given.
        num_videos_per_prompt (`int`):
            number of videos that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the video generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the video generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        A 4-tuple `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)`.
        When `do_classifier_free_guidance` is false the two negative entries are returned exactly as they were
        passed in.

    Raises:
        TypeError: If `negative_prompt` and `prompt` end up with different types.
        ValueError: If the negative prompt batch size differs from the prompt batch size.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if self.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(self.text_encoder_2, lora_scale)

    # Normalize a single string to a one-element batch.
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt is not None:
        batch_size = len(prompt)
    else:
        # No text prompt: the batch size comes from the pre-computed embeddings.
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders
    # When the first tokenizer / encoder is absent, only the second one is used.
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # textual inversion: process multi-vector tokens if necessary
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        # NOTE: `zip` stops at the shortest list, so with a single tokenizer/encoder only `prompts[0]` is encoded.
        # NOTE: the loop variable rebinds `prompt`; after the loop it holds the last prompt that was processed.
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, tokenizer)

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            # Tokenize again without truncation only to detect (and warn about) text that was cut off.
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

            # We are only ALWAYS interested in the pooled output of the final text encoder
            # (the assignment is overwritten on every iteration, so the last encoder's output wins)
            pooled_prompt_embeds = prompt_embeds[0]
            if clip_skip is None:
                prompt_embeds = prompt_embeds.hidden_states[-2]
            else:
                # "2" because SDXL always indexes from the penultimate layer.
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

            prompt_embeds_list.append(prompt_embeds)

        # Join the per-encoder hidden states along the feature (last) dimension.
        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        # No negative prompt and the config asks for zeros: skip encoding entirely.
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # normalize str to list
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt_2 = (
            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
        )

        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = [negative_prompt, negative_prompt_2]

        negative_prompt_embeds_list = []
        # NOTE: as above, the loop variable rebinds `negative_prompt`.
        for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

            # Pad/truncate the negative prompt to the sequence length of the positive embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            # NOTE: `clip_skip` is not consulted here; the penultimate hidden state is always used.
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    # Cast to the dtype of the second text encoder when present, otherwise to the UNet dtype.
    if self.text_encoder_2 is not None:
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    else:
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        if self.text_encoder_2 is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    # NOTE(review): `pooled_prompt_embeds` must be non-None here, i.e. either passed in together with
    # `prompt_embeds` or produced by the encoding loop above.
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view(
        bs_embed * num_videos_per_prompt, -1
    )
    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view(
            bs_embed * num_videos_per_prompt, -1
        )

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build the per-adapter list of IP-Adapter image embeddings.

    Embeddings are taken from `ip_adapter_image_embeds` when given; otherwise each entry of `ip_adapter_image`
    is encoded with `self.encode_image`, one entry per image projection layer of the UNet. Every resulting
    tensor is repeated `num_images_per_prompt` times along dim 0 and, under classifier-free guidance, prefixed
    with its equally repeated negative counterpart.

    Returns:
        `List[torch.Tensor]`: one tensor per adapter, moved to `device`.

    Raises:
        ValueError: If the number of images differs from the number of image projection layers.
    """
    positive_embeds = []
    # Only needed (and only touched) when guidance is enabled.
    negative_embeds = [] if do_classifier_free_guidance else None

    if ip_adapter_image_embeds is not None:
        # Pre-computed embeddings: under guidance each tensor stacks [negative, positive] along dim 0.
        for precomputed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_part, precomputed = precomputed.chunk(2)
                negative_embeds.append(negative_part)
            positive_embeds.append(precomputed)
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        projection_layers = self.unet.encoder_hid_proj.image_projection_layers
        if len(ip_adapter_image) != len(projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        for adapter_image, projection_layer in zip(ip_adapter_image, projection_layers):
            # Layers other than the plain `ImageProjection` are fed encoder hidden states.
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            encoded, encoded_negative = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positive_embeds.append(encoded[None, :])
            if do_classifier_free_guidance:
                negative_embeds.append(encoded_negative[None, :])

    prepared = []
    for index, embeds in enumerate(positive_embeds):
        embeds = torch.cat([embeds] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            repeated_negative = torch.cat([negative_embeds[index]] * num_images_per_prompt, dim=0)
            embeds = torch.cat([repeated_negative, embeds], dim=0)
        prepared.append(embeds.to(device=device))

    return prepared
def check_inputs(
    self,
    prompt,
    prompt_2,
    height,
    width,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the user-facing arguments of the pipeline call.

    Checks are run in a fixed order and the first violation raises. Nothing is returned on success.

    Raises:
        ValueError: If `height`/`width` are not multiples of 8, if an unknown callback tensor name is requested,
            if a text prompt and its pre-computed embeddings are both (or neither) supplied, if a prompt has an
            unsupported type, if positive and negative embeddings differ in shape, or if embeddings are supplied
            without their pooled counterpart.
    """
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and any(
        k not in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    has_prompt_embeds = prompt_embeds is not None
    has_negative_embeds = negative_prompt_embeds is not None

    # --- positive prompt: exactly one of text / embeddings, and text must be str or list ---
    if prompt is not None and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt_2 is not None and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    if prompt_2 is not None and not isinstance(prompt_2, (str, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    # --- negative prompt: text and embeddings are mutually exclusive ---
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    if negative_prompt_2 is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # --- embeddings passed directly must agree in shape and come with their pooled variants ---
    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    if has_prompt_embeds and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if has_negative_embeds and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )
# Adapted from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
# (the frequency table is created on `w.device` instead of always on the CPU)
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Build sinusoidal embeddings of guidance scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scale values to embed. It is multiplied by 1000 before embedding.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate. An odd value is reached by zero-padding one column.
            NOTE: values of 2 or 3 make `half_dim - 1` zero in the frequency computation below.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, on the same device as `w`.

    Raises:
        AssertionError: If `w` is not 1-D, or if the result does not have the expected shape.
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    # Log-spaced frequency step; kept as a 0-dim tensor so numerics match the reference implementation.
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # Create the frequency table on the device of `w`; a CPU-only table would fail to multiply with a
    # non-CPU `w` in the next step.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, original_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Optional[Tuple[int, int]] = None, negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders num_frames: The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds amounts to 2 seconds of video. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated video. This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated video. 
This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower video quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the video generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. 
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a specific image resolution. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a target image resolution. It should be as same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. """ # 0. Default height and width to unet height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor num_videos_per_prompt = 1 original_size = original_size or (height, width) target_size = target_size or (height, width) # 1. Check inputs. 
Raise error if not correct self.check_inputs( prompt, prompt_2, height, width, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._guidance_rescale = guidance_rescale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. Encode input prompt lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, device=device, num_videos_per_prompt=num_videos_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=lora_scale, clip_skip=self.clip_skip, ) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, num_frames, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt, self.do_classifier_free_guidance, ) # 7.1 Apply denoising_end if ( self.denoising_end is not None and isinstance(self.denoising_end, float) and self.denoising_end > 0 and self.denoising_end < 1 ): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # 8. 
Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_videos_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 for free_init_iter in range(num_free_init_iters): if self.free_init_enabled: latents, timesteps = self._apply_free_init( latents, free_init_iter, num_inference_steps, device, latents.dtype, generator ) self._num_timesteps = len(timesteps) # 9. Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg( noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale ) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) progress_bar.update() # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) # 10. Post processing if output_type == "latent": video = latents else: video_tensor = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) # 11. Offload all models self.maybe_free_model_hooks() if not return_dict: return (video,) return AnimateDiffPipelineOutput(frames=video) ================================================ FILE: src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py ================================================ # Copyright 2024 The HuggingFace Team. 
All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.controlnet_sparsectrl import SparseControlNetModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unets.unet_motion_model import MotionAdapter from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import is_compiled_module, randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```python >>> import torch >>> from diffusers import AnimateDiffSparseControlNetPipeline >>> from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel 
>>> from diffusers.schedulers import DPMSolverMultistepScheduler >>> from diffusers.utils import export_to_gif, load_image >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE" >>> motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3" >>> controlnet_id = "guoyww/animatediff-sparsectrl-scribble" >>> lora_adapter_id = "guoyww/animatediff-motion-lora-v1-5-3" >>> vae_id = "stabilityai/sd-vae-ft-mse" >>> device = "cuda" >>> motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device) >>> controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device) >>> vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device) >>> scheduler = DPMSolverMultistepScheduler.from_pretrained( ... model_id, ... subfolder="scheduler", ... beta_schedule="linear", ... algorithm_type="dpmsolver++", ... use_karras_sigmas=True, ... ) >>> pipe = AnimateDiffSparseControlNetPipeline.from_pretrained( ... model_id, ... motion_adapter=motion_adapter, ... controlnet=controlnet, ... vae=vae, ... scheduler=scheduler, ... torch_dtype=torch.float16, ... ).to(device) >>> pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora") >>> pipe.fuse_lora(lora_scale=1.0) >>> prompt = "an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality" >>> negative_prompt = "low quality, worst quality, letterboxed" >>> image_files = [ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png", ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png", ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png", ... ] >>> condition_frame_indices = [0, 8, 15] >>> conditioning_frames = [load_image(img_file) for img_file in image_files] >>> video = pipe( ... prompt=prompt, ... 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Extract latents from an encoder output.

    When `encoder_output` exposes a `latent_dist` attribute, the latents are drawn from it with
    `sample(generator)` if `sample_mode` is `"sample"`, or taken from `mode()` if `sample_mode` is `"argmax"`.
    Otherwise a `latents` attribute is returned as-is when present.

    Raises:
        AttributeError: If none of the cases above applies.
    """
    has_latent_dist = hasattr(encoder_output, "latent_dist")

    if has_latent_dist and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)

    if has_latent_dist and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()

    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer (`CLIPTokenizer`): A [`~transformers.CLIPTokenizer`] to tokenize text. unet ([`UNet2DConditionModel`]): A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. motion_adapter ([`MotionAdapter`]): A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: Union[UNet2DConditionModel, UNetMotionModel],
    motion_adapter: MotionAdapter,
    controlnet: SparseControlNetModel,
    scheduler: KarrasDiffusionSchedulers,
    feature_extractor: CLIPImageProcessor = None,
    image_encoder: CLIPVisionModelWithProjection = None,
):
    """
    Register the pipeline components and set up the video / control-image processors.

    A `UNet2DConditionModel` passed as `unet` is first converted with `UNetMotionModel.from_unet2d` using
    `motion_adapter`; any other `unet` is registered unchanged.
    """
    super().__init__()

    motion_unet = unet
    if isinstance(motion_unet, UNet2DConditionModel):
        motion_unet = UNetMotionModel.from_unet2d(motion_unet, motion_adapter)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=motion_unet,
        motion_adapter=motion_adapter,
        controlnet=controlnet,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )

    # The scale factor is derived from the number of entries in the VAE's `block_out_channels` config.
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)

    scale = self.vae_scale_factor
    self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=scale)
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=scale, do_convert_rgb=True, do_normalize=False
    )
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the 
encoder layers. Then index into # the tuple to access the hidden states from the desired layer. prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode an IP-Adapter image with `self.image_encoder` and build the matching unconditional features.

    Args:
        image: A `torch.Tensor`, or any input accepted by `self.feature_extractor` (used when `image` is not a
            tensor).
        device: Device the pixel values are moved to before encoding.
        num_images_per_prompt (`int`): Number of times every encoded sample is repeated along dim 0.
        output_hidden_states (`bool`, *optional*): When truthy, the second-to-last entry of the encoder's
            `hidden_states` is returned instead of the pooled `image_embeds`.

    Returns:
        `tuple`: `(conditional, unconditional)` features. With hidden states, the unconditional features come from
        encoding an all-zero image; otherwise they are zeros shaped like the conditional embeddings.
    """
    encoder = self.image_encoder
    encoder_dtype = next(encoder.parameters()).dtype

    # Non-tensor inputs go through the feature extractor to obtain pixel values first.
    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        pooled = encoder(image).image_embeds
        pooled = pooled.repeat_interleave(num_images_per_prompt, dim=0)
        return pooled, torch.zeros_like(pooled)

    def penultimate_states(pixels):
        # Second-to-last hidden state, repeated once per requested image.
        states = encoder(pixels, output_hidden_states=True).hidden_states[-2]
        return states.repeat_interleave(num_images_per_prompt, dim=0)

    cond_states = penultimate_states(image)
    uncond_states = penultimate_states(torch.zeros_like(image))
    return cond_states, uncond_states
def decode_latents(self, latents):
    """
    Decode video latents frame by frame with `self.vae` and return a float32 video tensor.

    Args:
        latents (`torch.Tensor`): 5D latents unpacked as `(batch, channels, frames, height, width)`.

    Returns:
        `torch.Tensor`: The decoded frames, rearranged back to `(batch, channels, frames, ...)` and cast to float32.
    """
    vae = self.vae
    unscaled = 1 / vae.config.scaling_factor * latents

    n_batch, n_channels, n_frames, lat_h, lat_w = unscaled.shape
    # Fold the frame axis into the batch axis so the VAE sees a batch of 2D latents.
    frame_major = unscaled.permute(0, 2, 1, 3, 4)
    flat_frames = frame_major.reshape(n_batch * n_frames, n_channels, lat_h, lat_w)

    decoded = vae.decode(flat_frames).sample

    # Unfold the frame axis again and move channels back in front of frames.
    video_shape = (n_batch, n_frames, -1) + decoded.shape[2:]
    stacked = decoded[None, :].reshape(video_shape)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return stacked.permute(0, 2, 1, 3, 4).float()
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are only forwarded when the
    scheduler's `step` declares a parameter of that name.

    Args:
        generator: Random generator(s) to forward under the `generator` key.
        eta: Value to forward under the `eta` key. eta (η) is only used with the DDIMScheduler and corresponds to η
            in the DDIM paper (https://arxiv.org/abs/2010.02502); it should be between [0, 1].

    Returns:
        `dict`: Contains `"eta"` and/or `"generator"` depending on what the scheduler accepts.
    """
    step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
    # Insertion order ("eta" first, then "generator") is kept on purpose.
    candidates = {"eta": eta, "generator": generator}
    return {name: value for name, value in candidates.items() if name in step_params}
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
def check_image(self, image, prompt, prompt_embeds):
    """
    Validate the type of a conditioning image and its batch size against the prompt batch size.

    Args:
        image: A PIL image, numpy array, torch tensor, or a non-empty list of one of those.
        prompt (`str` or `List[str]`, *optional*): Prompt(s) used to derive the prompt batch size.
        prompt_embeds (`torch.Tensor`, *optional*): Used for the prompt batch size (`shape[0]`) when `prompt` is
            neither a `str` nor a `list`.

    Raises:
        TypeError: If `image` is none of the supported types. An empty list is reported here as well instead of
            failing with an `IndexError` on `image[0]`.
        ValueError: If the image batch size is not 1 and either differs from the prompt batch size or the prompt
            batch size cannot be determined.
    """

    def _is_list_of(kind):
        # Only the first element is inspected; the emptiness guard keeps `image[0]` from raising IndexError.
        return isinstance(image, list) and len(image) > 0 and isinstance(image[0], kind)

    image_is_pil = isinstance(image, PIL.Image.Image)
    image_is_tensor = isinstance(image, torch.Tensor)
    image_is_np = isinstance(image, np.ndarray)
    image_is_pil_list = _is_list_of(PIL.Image.Image)
    image_is_tensor_list = _is_list_of(torch.Tensor)
    image_is_np_list = _is_list_of(np.ndarray)

    if not (
        image_is_pil
        or image_is_tensor
        or image_is_np
        or image_is_pil_list
        or image_is_tensor_list
        or image_is_np_list
    ):
        raise TypeError(
            f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
        )

    # A single PIL image counts as a batch of one; everything else reports its own length.
    image_batch_size = 1 if image_is_pil else len(image)

    if isinstance(prompt, str):
        prompt_batch_size = 1
    elif isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]
    else:
        # Unknown; only an error if the comparison below is actually needed.
        prompt_batch_size = None

    if image_batch_size != 1:
        if prompt_batch_size is None:
            raise ValueError(
                "Cannot validate the image batch size: provide `prompt` as `str` or `list`, or pass `prompt_embeds`."
            )
        if image_batch_size != prompt_batch_size:
            raise ValueError(
                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
            )
def prepare_image(self, image, width, height, device, dtype):
    """
    Preprocess the conditioning frames and bring them into the layout passed on as ControlNet conditioning.

    Args:
        image: Conditioning frame(s), forwarded to `self.control_image_processor.preprocess`.
        width (`int`): Target width passed to the preprocessor.
        height (`int`): Target height passed to the preprocessor.
        device: Device the preprocessed frames are moved to.
        dtype: Dtype the preprocessed frames are cast to.

    Returns:
        `torch.Tensor`: Conditioning frames permuted to `[b, c, f, h, w]`. When
        `self.controlnet.use_simplified_condition_embedding` is set these are VAE latents scaled by
        `self.vae.config.scaling_factor`, otherwise the preprocessed frames themselves.

    Raises:
        ValueError: If the preprocessed frames are not within `[0, 1]`.
    """
    image = self.control_image_processor.preprocess(image, height=height, width=width)
    # The preprocessed frames get a leading batch axis of size 1.
    controlnet_images = image.unsqueeze(0).to(device, dtype)
    batch_size, num_frames, channels, height, width = controlnet_images.shape

    # Explicit check instead of `assert`, so it is not stripped when Python runs with -O.
    lowest, highest = controlnet_images.min(), controlnet_images.max()
    if not (lowest >= 0 and highest <= 1):
        raise ValueError(
            f"Conditioning frames must be in the range [0, 1] after preprocessing, but got values in"
            f" [{lowest.item()}, {highest.item()}]."
        )

    if self.controlnet.use_simplified_condition_embedding:
        controlnet_images = controlnet_images.reshape(batch_size * num_frames, channels, height, width)
        # Rescale [0, 1] -> [-1, 1] before VAE encoding.
        controlnet_images = 2 * controlnet_images - 1
        conditioning_frames = retrieve_latents(self.vae.encode(controlnet_images)) * self.vae.config.scaling_factor
        # Take channel and spatial sizes from the encoder output rather than assuming 4 latent channels.
        conditioning_frames = conditioning_frames.reshape(batch_size, num_frames, *conditioning_frames.shape[1:])
    else:
        conditioning_frames = controlnet_images

    conditioning_frames = conditioning_frames.permute(0, 2, 1, 3, 4)  # [b, c, f, h, w]
    return conditioning_frames
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_frames: int = 16,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: int = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    conditioning_frames: Optional[List[PipelineImageInput]] = None,
    output_type: str = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    controlnet_frame_indices: List[int] = [0],
    guess_mode: bool = False,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated video.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated video.
        num_frames (`int`, *optional*, defaults to 16):
            The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
            amounts to 2 seconds of video.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to higher quality videos at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Accepted for interface compatibility; the value is currently overridden to 1 inside this method.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
            `(batch_size, num_channel, num_frames, height, width)`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        conditioning_frames (`List[PipelineImageInput]`, *optional*):
            The SparseControlNet input to provide guidance to the `unet` for generation.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
            With `"latent"`, the final latents are returned without decoding.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] instead
            of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its `"scale"` entry, if present, is also used as the LoRA scale when encoding the prompt.
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
            the corresponding scale as a list. Note that `check_inputs` currently raises a `TypeError` for any
            value that is not a `float`.
        controlnet_frame_indices (`List[int]`):
            The indices where the conditioning frames must be applied for generation. Multiple frames can be
            provided to guide the model to generate similar structure outputs, where the `unet` can
            "fill-in-the-gaps" for interpolation videos, or a single frame could be provided for general expected
            structure. Must have the same length as `conditioning_frames`.
        guess_mode (`bool`, *optional*, defaults to `False`):
            Also enabled when the ControlNet config sets `global_pool_conditions`. When active together with
            classifier-free guidance, the ControlNet is run only on the conditional half of the batch.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`. The returned dict may override `latents`, `prompt_embeds` and
            `negative_prompt_embeds`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.

    Examples:

    Returns:
        [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
            returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
    """
    # Unwrap a compiled ControlNet so its config and dtype can be read from the original module.
    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # NOTE: the caller-supplied `num_videos_per_prompt` is discarded here; only one video per prompt is generated.
    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        ip_adapter_image=ip_adapter_image,
        ip_adapter_image_embeds=ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        image=conditioning_frames,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
    )

    # Stored on the instance so the read-only properties (`guidance_scale`, `clip_skip`, ...) reflect this call.
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    global_pool_conditions = (
        controlnet.config.global_pool_conditions
        if isinstance(controlnet, SparseControlNetModel)
        else controlnet.nets[0].config.global_pool_conditions
    )
    guess_mode = guess_mode or global_pool_conditions

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_videos_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare IP-Adapter embeddings
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 5. Prepare controlnet conditioning
    conditioning_frames = self.prepare_image(conditioning_frames, width, height, device, controlnet.dtype)
    controlnet_cond, controlnet_cond_mask = self.prepare_sparse_control_conditioning(
        conditioning_frames, num_frames, controlnet_frame_indices, device, controlnet.dtype
    )

    # 6. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 7. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        num_frames,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 9. Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    # Without FreeInit the whole denoising loop below runs exactly once.
    num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
    for free_init_iter in range(num_free_init_iters):
        if self.free_init_enabled:
            latents, timesteps = self._apply_free_init(
                latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
            )

        self._num_timesteps = len(timesteps)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # 10. Denoising loop
        with self.progress_bar(total=self._num_timesteps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if guess_mode and self.do_classifier_free_guidance:
                    # Infer SparseControlNetModel only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                    # `prompt_embeds` is [negative, positive]; the second chunk is the conditional half.
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=controlnet_cond,
                    conditioning_mask=controlnet_cond_mask,
                    conditioning_scale=controlnet_conditioning_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        # Requested tensors are looked up by local variable name, so the locals of this
                        # method (e.g. `latents`, `prompt_embeds`) must keep their names.
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # The callback may hand back replacements for these tensors.
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # advance the progress bar on the last step, and after warmup once per completed
                # group of `scheduler.order` steps
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

    # 11. Post processing
    if output_type == "latent":
        video = latents
    else:
        video_tensor = self.decode_latents(latents)
        video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

    # 12. Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return AnimateDiffPipelineOutput(frames=video)
import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unets.unet_motion_model import MotionAdapter from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, ) from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import imageio >>> import requests >>> import torch >>> from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter >>> from diffusers.utils import export_to_gif >>> from io import BytesIO >>> from PIL import Image >>> adapter = MotionAdapter.from_pretrained( ... "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16 ... ) >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained( ... "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter ... ).to("cuda") >>> pipe.scheduler = DDIMScheduler( ... beta_schedule="linear", steps_offset=1, clip_sample=False, timestep_spacing="linspace" ... ) >>> def load_video(file_path: str): ... images = [] ... if file_path.startswith(("http://", "https://")): ... 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` via its `set_timesteps` method and read the resulting schedule back from
    `scheduler.timesteps`. Extra keyword arguments are forwarded to `scheduler.set_timesteps` unchanged.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and to read the timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given; in that case
            it is passed positionally to `set_timesteps` and returned as-is.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as the `device` keyword.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the scheduler's timestep schedule and the number of inference steps. With a
        custom schedule the step count is the length of the schedule reported by the scheduler.

    Raises:
        ValueError: if both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` has no
            parameter matching the custom schedule that was requested.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own spacing from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # A custom schedule is only usable when `set_timesteps` exposes the matching keyword.
    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
""" model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, motion_adapter: MotionAdapter, scheduler: Union[ DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ], feature_extractor: CLIPImageProcessor = None, image_encoder: CLIPVisionModelWithProjection = None, ): super().__init__() if isinstance(unet, UNet2DConditionModel): unet = UNetMotionModel.from_unet2d(unet, motion_adapter) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, motion_adapter=motion_adapter, scheduler=scheduler, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the 
encoder layers. Then index into # the tuple to access the hidden states from the desired layer. prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    r"""
    Run `image` through `self.image_encoder` and return a `(conditional, unconditional)` pair.

    Args:
        image:
            A `torch.Tensor`, or any other input, which is first converted with `self.feature_extractor`
            (its `pixel_values` are used).
        device:
            Device the image tensor is moved to before encoding.
        num_images_per_prompt (`int`):
            Each result is repeated this many times along dim 0 with `repeat_interleave`.
        output_hidden_states (`bool`, *optional*):
            When truthy, the second-to-last entry of the encoder's `hidden_states` is returned for the image
            and for an all-zero image of the same shape. Otherwise the encoder's `image_embeds` are returned
            together with a zero tensor of the same shape.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: conditional and unconditional image features.
    """
    encoder_dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = self.image_encoder(image).image_embeds
        cond_embeds = cond_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    def _penultimate_states(pixels):
        # Second-to-last hidden state, repeated once per requested image.
        states = self.image_encoder(pixels, output_hidden_states=True).hidden_states[-2]
        return states.repeat_interleave(num_images_per_prompt, dim=0)

    cond_states = _penultimate_states(image)
    # The unconditional branch encodes an all-zero image rather than zeroing the features.
    uncond_states = _penultimate_states(torch.zeros_like(image))
    return cond_states, uncond_states
def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor:
    r"""
    Encode `video` with the VAE in slices and return the concatenated latents.

    Args:
        video:
            Sliceable sequence supporting `len()`; each slice is passed to `self.vae.encode`.
        generator:
            Forwarded to `retrieve_latents` for every slice.
        decode_chunk_size (`int`, defaults to 16):
            Maximum number of items along the first dimension encoded per VAE call.

    Returns:
        `torch.Tensor`: the per-slice latents joined with `torch.cat` along dim 0.
    """
    chunk_starts = range(0, len(video), decode_chunk_size)
    encoded_chunks = [
        retrieve_latents(self.vae.encode(video[start : start + decode_chunk_size]), generator=generator)
        for start in chunk_starts
    ]
    return torch.cat(encoded_chunks)
def prepare_extra_step_kwargs(self, generator, eta):
    r"""
    Collect the optional keyword arguments that `self.scheduler.step` is able to accept.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are only included when a
    parameter of that name exists. eta (η) is only used with the DDIMScheduler and corresponds to η in the
    DDIM paper (https://arxiv.org/abs/2010.02502); it should be between [0, 1].

    Args:
        generator: Value to pass as `generator` if the scheduler's `step` accepts it.
        eta: Value to pass as `eta` if the scheduler's `step` accepts it.

    Returns:
        `dict`: a subset of `{"eta": eta, "generator": generator}`, in that key order.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    optional_args = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_args if name in step_params}
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if video is not None and latents is not None: raise ValueError("Only one of `video` or `latents` should be provided") if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
def prepare_latents(
    self,
    video,
    height,
    width,
    num_channels_latents,
    batch_size,
    timestep,
    dtype,
    device,
    generator,
    latents=None,
    decode_chunk_size: int = 16,
):
    r"""
    Produce the latents the denoising loop starts from.

    If `latents` is supplied it is validated against the expected
    `(batch_size, num_channels_latents, num_frames, height // vae_scale_factor, width // vae_scale_factor)`
    shape and moved to `device`/`dtype`. Otherwise `video` is encoded with the VAE (through
    `self.encode_video`), multiplied by `vae.config.scaling_factor`, noised with `scheduler.add_noise` at
    `timestep`, and permuted with `(0, 2, 1, 3, 4)`.

    Args:
        video: Input video tensor; `video.shape[1]` is taken as the number of frames. Only read when
            `latents` is `None`.
        height, width (`int`): Output resolution in pixels, used for the expected latent shape.
        num_channels_latents (`int`): Channel count of the expected latent shape.
        batch_size (`int`): Effective batch size.
        timestep: Timestep(s) passed to `scheduler.add_noise`.
        dtype, device: Target dtype and device of the returned latents.
        generator: A generator or a list of generators; a list must have `batch_size` entries.
        latents (`torch.Tensor`, *optional*): Pre-computed latents that skip encoding and noising.
        decode_chunk_size (`int`, defaults to 16): Chunk size forwarded to `self.encode_video`.

    Returns:
        `torch.Tensor`: the prepared latents.

    Raises:
        ValueError: if a generator list does not match `batch_size`, if `latents` has an unexpected shape, or
            if `batch_size` is larger than the number of encoded videos.
    """
    if latents is None:
        num_frames = video.shape[1]
    else:
        num_frames = latents.shape[2]

    shape = (
        batch_size,
        num_channels_latents,
        num_frames,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )

    # Validated once up front; covers both the encoding path and user-provided latents.
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        if shape != latents.shape:
            # [B, C, F, H, W]
            raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
        return latents.to(device, dtype=dtype)

    # make sure the VAE is in float32 mode, as it overflows in float16
    upcast_vae = self.vae.config.force_upcast
    if upcast_vae:
        video = video.float()
        self.vae.to(dtype=torch.float32)

    try:
        if isinstance(generator, list):
            init_latents = [
                self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
                for i in range(batch_size)
            ]
        else:
            init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video]

        init_latents = torch.cat(init_latents, dim=0)
    finally:
        # restore vae to original dtype, even when encoding fails, so the pipeline is not left
        # with a silently upcast VAE
        if upcast_vae:
            self.vae.to(dtype)

    init_latents = init_latents.to(dtype)
    init_latents = self.vae.config.scaling_factor * init_latents

    num_encoded = init_latents.shape[0]
    if batch_size > num_encoded:
        # Duplicating the encoded videos to match the prompt count is not supported.
        if batch_size % num_encoded == 0:
            raise ValueError(
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
            )
        raise ValueError(
            f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
        )

    noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
    latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
    return latents
@torch.no_grad()
def __call__(
    self,
    video: List[List[PipelineImageInput]] = None,
    prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    guidance_scale: float = 7.5,
    strength: float = 0.8,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    decode_chunk_size: int = 16,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        video (`List[PipelineImageInput]`):
            The input video to condition the generation on. Must be a list of images/frames of the video.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated video.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated video.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        strength (`float`, *optional*, defaults to 0.8):
            Higher strength leads to more differences between original video and generated video.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Accepted for interface compatibility; the value is overwritten with `1` at the start of the call, so
            any other value has no effect.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
            `(batch_size, num_channel, num_frames, height, width)`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
            Passing `"latent"` returns the denoised latents without decoding them.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its `"scale"` entry, when present, is also forwarded to `encode_prompt` as `lora_scale`.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        decode_chunk_size (`int`, defaults to `16`):
            The number of frames to decode at a time when calling `decode_latents` method. The same value is also
            passed to `prepare_latents` for encoding the input video.

    Examples:

    Returns:
        [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
            returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
    """

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # The caller-supplied value is discarded here: exactly one video is generated per prompt.
    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt=prompt,
        strength=strength,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        video=video,
        latents=latents,
        ip_adapter_image=ip_adapter_image,
        ip_adapter_image_embeds=ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    # Stored on the instance so the read-only properties (guidance_scale, clip_skip, ...) reflect this call.
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_videos_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # `image_embeds` is only bound when an IP-Adapter input was given; step 7 below uses the same condition.
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )
    # Trim the schedule according to `strength`; the first remaining timestep is the noise level applied to
    # the encoded input video in `prepare_latents`.
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
    latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)

    # 5. Prepare latent variables
    if latents is None:
        video = self.video_processor.preprocess_video(video, height=height, width=width)
        # Move the number of frames before the number of channels.
        video = video.permute(0, 2, 1, 3, 4)
        video = video.to(device=device, dtype=prompt_embeds.dtype)
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        video=video,
        height=height,
        width=width,
        num_channels_latents=num_channels_latents,
        batch_size=batch_size * num_videos_per_prompt,
        timestep=latent_timestep,
        dtype=prompt_embeds.dtype,
        device=device,
        generator=generator,
        latents=latents,
        decode_chunk_size=decode_chunk_size,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    # Without FreeInit the denoising loop below runs exactly once.
    num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
    for free_init_iter in range(num_free_init_iters):
        if self.free_init_enabled:
            latents, timesteps = self._apply_free_init(
                latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
            )
            num_inference_steps = len(timesteps)
            # make sure to readjust timesteps based on strength
            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)

        self._num_timesteps = len(timesteps)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # 8. Denoising loop
        with self.progress_bar(total=self._num_timesteps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                ).sample

                # perform guidance
                if self.do_classifier_free_guidance:
                    # The batch was built as [unconditional, text], so the halves come back in that order.
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if callback_on_step_end is not None:
                    # Requested tensors are looked up by name among this function's local variables.
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # The callback may hand back replacements for these tensors; otherwise keep the current ones.
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

    # 9. Post-processing
    if output_type == "latent":
        video = latents
    else:
        video_tensor = self.decode_latents(latents, decode_chunk_size)
        video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

    # 10. Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return AnimateDiffPipelineOutput(frames=video)
...utils.dummy_torch_and_transformers_objects import ( AudioLDMPipeline, ) else: from .pipeline_audioldm import AudioLDMPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/audioldm/pipeline_audioldm.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch import torch.nn.functional as F from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import AudioLDMPipeline >>> import torch >>> import scipy >>> repo_id = "cvssp/audioldm-s-full-v2" >>> pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") >>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs" >>> audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0] >>> # save the audio sample as a .wav file >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio) ``` """ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin): r""" Pipeline for text-to-audio generation using AudioLDM. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.ClapTextModelWithProjection`]): Frozen text-encoder (`ClapTextModelWithProjection`, specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. tokenizer ([`PreTrainedTokenizer`]): A [`~transformers.RobertaTokenizer`] to tokenize text. 
unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded audio latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. vocoder ([`~transformers.SpeechT5HifiGan`]): Vocoder of class `SpeechT5HifiGan`. """ model_cpu_offload_seq = "text_encoder->unet->vae" def __init__( self, vae: AutoencoderKL, text_encoder: ClapTextModelWithProjection, tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast], unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, vocoder: SpeechT5HifiGan, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, vocoder=vocoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def _encode_prompt( self, prompt, device, num_waveforms_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device (`torch.device`): torch device num_waveforms_per_prompt (`int`): number of waveforms that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the audio generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLAP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask.to(device), ) prompt_embeds = prompt_embeds.text_embeds # additional L_2 normalization over each hidden-state prompt_embeds = F.normalize(prompt_embeds, dim=-1) prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) ( bs_embed, seq_len, ) = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt) prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * 
batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) uncond_input_ids = uncond_input.input_ids.to(device) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input_ids, attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds.text_embeds # additional L_2 normalization over each hidden-state negative_prompt_embeds = F.normalize(negative_prompt_embeds, dim=-1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents mel_spectrogram = self.vae.decode(latents).sample return mel_spectrogram def mel_spectrogram_to_waveform(self, mel_spectrogram): if mel_spectrogram.dim() == 4: mel_spectrogram = mel_spectrogram.squeeze(1) waveform = self.vocoder(mel_spectrogram) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 waveform = waveform.cpu().float() return waveform # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, audio_length_in_s, vocoder_upsample_factor, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor if audio_length_in_s < min_audio_length_in_s: raise ValueError( f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but " f"is {audio_length_in_s}." 
) if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0: raise ValueError( f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the " f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of " f"{self.vae_scale_factor}." ) if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." 
def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
    """
    Return the starting latents for the denoising loop, scaled for the scheduler.

    The expected latent shape is `(batch_size, num_channels_latents, height // vae_scale_factor,
    vocoder.config.model_in_dim // vae_scale_factor)`. Fresh noise of that shape is drawn when `latents` is
    `None`; otherwise the supplied tensor is moved to `device` as-is. Either way the result is multiplied by
    `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(self.vocoder.config.model_in_dim) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    initial = (
        randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if latents is None
        else latents.to(device)
    )

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma
audio_length_in_s (`int`, *optional*, defaults to 5.12): The length of the generated audio sample in seconds. num_inference_steps (`int`, *optional*, defaults to 10): The number of denoising steps. More denoising steps usually lead to a higher quality audio at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 2.5): A higher guidance scale value encourages the model to generate audio that is closely linked to the text `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in audio generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_waveforms_per_prompt (`int`, *optional*, defaults to 1): The number of waveforms to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). output_type (`str`, *optional*, defaults to `"np"`): The output format of the generated image. Choose between `"np"` to return a NumPy `np.ndarray` or `"pt"` to return a PyTorch `torch.Tensor` object. Examples: Returns: [`~pipelines.AudioPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated audio. """ # 0. 
Convert audio input length from seconds to spectrogram height vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate if audio_length_in_s is None: audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor height = int(audio_length_in_s / vocoder_upsample_factor) original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate) if height % self.vae_scale_factor != 0: height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor logger.info( f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} " f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the " f"denoising process." ) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, audio_length_in_s, vocoder_upsample_factor, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_waveforms_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. 
Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_waveforms_per_prompt, num_channels_latents, height, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=None, class_labels=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # 8. 
# Package initializer for the AudioLDM2 pipeline: exposes the public classes lazily so that importing the
# package does not eagerly import the model / pipeline modules.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
)


# Placeholder objects installed on the module when the optional dependencies are missing.
_dummy_objects = {}
# Maps submodule name -> list of public names it provides; handed to `_LazyModule` below.
_import_structure = {}

try:
    # The real classes are only registered when torch and transformers (>= 4.27.0) are both available.
    if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects

    # Collect the dummy stand-ins so they can be attached to the lazy module at the end of this file.
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modeling_audioldm2"] = ["AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel"]
    _import_structure["pipeline_audioldm2"] = ["AudioLDM2Pipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: perform the real imports (same dependency check as above) so that static type checkers,
    # or a run with `DIFFUSERS_SLOW_IMPORT` set, see the concrete names.
    try:
        if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
        from .pipeline_audioldm2 import AudioLDM2Pipeline

else:
    import sys

    # Lazy path: replace this module object in `sys.modules` with a `_LazyModule` built from
    # `_import_structure`.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach the dummy placeholders (empty when the dependencies are available) to the lazy module.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
def add_special_tokens(hidden_states, attention_mask, sos_token, eos_token):
    """
    Wrap a batch of sequences with a start-of-sequence and an end-of-sequence embedding.

    Args:
        hidden_states (`torch.Tensor`):
            Batch of sequences; the SOS / EOS embeddings are inserted along dim 1.
        attention_mask (`torch.Tensor`, *optional*):
            Mask matching `hidden_states` along its last dim. When given, it is extended by one "attend" entry
            (value 1) on each side so that it stays aligned with the lengthened sequence.
        sos_token (`torch.Tensor`):
            Embedding placed in front of every sequence; broadcast to `(batch_size, 1, -1)`.
        eos_token (`torch.Tensor`):
            Embedding placed after every sequence; broadcast to `(batch_size, 1, -1)`.

    Returns:
        `tuple`: `(hidden_states, attention_mask)` where the sequence axis has grown by two positions. The mask is
        passed through as `None` when no mask was supplied.
    """
    num_rows = hidden_states.shape[0]

    extended_mask = attention_mask
    if extended_mask is not None:
        # The two new positions are always attended to, hence a column of ones on either side.
        edge = extended_mask.new_ones((num_rows, 1))
        extended_mask = torch.cat((edge, extended_mask, edge), dim=-1)

    # One copy of each special embedding per batch row, shaped as a length-1 sequence.
    leading = sos_token.expand(num_rows, 1, -1)
    trailing = eos_token.expand(num_rows, 1, -1)
    extended_states = torch.cat((leading, hidden_states, trailing), dim=1)

    return extended_states, extended_mask
""" @register_to_config def __init__( self, text_encoder_dim, text_encoder_1_dim, langauge_model_dim, use_learned_position_embedding=None, max_seq_length=None, ): super().__init__() # additional projection layers for each text encoder self.projection = nn.Linear(text_encoder_dim, langauge_model_dim) self.projection_1 = nn.Linear(text_encoder_1_dim, langauge_model_dim) # learnable SOS / EOS token embeddings for each text encoder self.sos_embed = nn.Parameter(torch.ones(langauge_model_dim)) self.eos_embed = nn.Parameter(torch.ones(langauge_model_dim)) self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim)) self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim)) self.use_learned_position_embedding = use_learned_position_embedding # learable positional embedding for vits encoder if self.use_learned_position_embedding is not None: self.learnable_positional_embedding = torch.nn.Parameter( torch.zeros((1, text_encoder_1_dim, max_seq_length)) ) def forward( self, hidden_states: Optional[torch.Tensor] = None, hidden_states_1: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None, attention_mask_1: Optional[torch.LongTensor] = None, ): hidden_states = self.projection(hidden_states) hidden_states, attention_mask = add_special_tokens( hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed ) # Add positional embedding for Vits hidden state if self.use_learned_position_embedding is not None: hidden_states_1 = (hidden_states_1.permute(0, 2, 1) + self.learnable_positional_embedding).permute(0, 2, 1) hidden_states_1 = self.projection_1(hidden_states_1) hidden_states_1, attention_mask_1 = add_special_tokens( hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1 ) # concatenate clap and t5 text encoding hidden_states = torch.cat([hidden_states, hidden_states_1], dim=1) # concatenate attention masks if attention_mask is None and attention_mask_1 is not None: attention_mask = 
attention_mask_1.new_ones((hidden_states[:2])) elif attention_mask is not None and attention_mask_1 is None: attention_mask_1 = attention_mask.new_ones((hidden_states_1[:2])) if attention_mask is not None and attention_mask_1 is not None: attention_mask = torch.cat([attention_mask, attention_mask_1], dim=-1) else: attention_mask = None return AudioLDM2ProjectionModelOutput( hidden_states=hidden_states, attention_mask=attention_mask, ) class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output. Compared to the vanilla [`UNet2DConditionModel`], this variant optionally includes an additional self-attention layer in each Transformer block, as well as multiple cross-attention layers. It also allows for up to two cross-attention embeddings, `encoder_hidden_states` and `encoder_hidden_states_1`. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): Block type for middle of UNet, it can only be `UNetMidBlock2DCrossAttn` for AudioLDM2. 
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): The tuple of upsample blocks to use. only_cross_attention (`bool` or `Tuple[bool]`, *optional*, default to `False`): Whether to include self-attention in the basic transformer blocks, see [`~models.attention.BasicTransformerBlock`]. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim` resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). 
Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. 
""" _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int]] = 1, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", time_embedding_type: str = "positional", time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, class_embeddings_concat: bool = False, ): super().__init__() self.sample_size = sample_size if num_attention_heads is not None: raise ValueError( "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." 
) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." 
) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time if time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError(f"{time_embedding_type} does not exist. Please make sure to use `positional`.") self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, ) # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. 
it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None if time_embedding_act_fn is None: self.time_embed_act = None else: self.time_embed_act = get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(down_block_types) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. 
The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the # regular time embeddings blocks_time_embed_dim = time_embed_dim * 2 else: blocks_time_embed_dim = time_embed_dim # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlock2DCrossAttn": self.mid_block = UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) else: raise ValueError( f"unknown mid_block_type : {mid_block_type}. Should be `UNetMidBlock2DCrossAttn` for AudioLDM2." 
) # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out if norm_num_groups is not None: self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) self.conv_act = get_activation(act_fn) else: self.conv_norm_out = None self.conv_act = None conv_out_padding = (conv_out_kernel - 1) // 2 self.conv_out = nn.Conv2d( 
@property
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
# (same behavior, written iteratively; the marker is not "Copied from" because the text no longer matches).
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Gather the attention processors of every sub-module that exposes `get_processor`.

    Returns:
        `dict` of attention processors: Maps `"<dotted.module.path>.processor"` to the processor returned by that
        module, in depth-first, parent-before-children order.
    """
    found = {}

    # Explicit stack instead of recursion. Entries are pushed in reverse so that popping visits siblings in
    # their registration order, which keeps the dictionary's insertion order identical to a recursive walk.
    pending = list(self.named_children())[::-1]
    while pending:
        path, node = pending.pop()

        if hasattr(node, "get_processor"):
            found[f"{path}.processor"] = node.get_processor()

        descendants = [(f"{path}.{child_name}", child) for child_name, child in node.named_children()]
        pending.extend(reversed(descendants))

    return found
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size="auto"):
    r"""
    Enable sliced attention computation.

    When this option is enabled, the attention module splits the input tensor in slices to compute attention in
    several steps. This is useful for saving some memory in exchange for a small decrease in speed.

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
            `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
            provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
            `attention_head_dim` must be a multiple of `slice_size`. A list supplies one value per sliceable
            layer, in traversal order; `None` entries skip the size check for that layer.

    Raises:
        ValueError: If a list is passed whose length differs from the number of sliceable layers, or if any
            requested size is larger than the matching layer's `sliceable_head_dim`.
    """
    sliceable_head_dims = []

    def collect_sliceable_dims(module: torch.nn.Module):
        # every module exposing `set_attention_slice` counts as one sliceable layer
        if hasattr(module, "set_attention_slice"):
            sliceable_head_dims.append(module.sliceable_head_dim)
        for child in module.children():
            collect_sliceable_dims(child)

    for module in self.children():
        collect_sliceable_dims(module)

    num_sliceable_layers = len(sliceable_head_dims)

    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between
        # speed and memory
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # make smallest slice possible
        slice_size = num_sliceable_layers * [1]

    # a scalar applies to every layer
    if not isinstance(slice_size, list):
        slice_size = num_sliceable_layers * [slice_size]

    if len(slice_size) != num_sliceable_layers:
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
        )

    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Walk the children in the same order as the collection pass and hand each sliceable
    # module its value. The list is reversed (a copy, so the caller's list is untouched)
    # so that `pop()` yields the values front-to-back.
    def assign_slices(module: torch.nn.Module, remaining: List[int]):
        if hasattr(module, "set_attention_slice"):
            module.set_attention_slice(remaining.pop())
        for child in module.children():
            assign_slices(child, remaining)

    remaining_sizes = list(reversed(slice_size))
    for module in self.children():
        assign_slices(module, remaining_sizes)
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. encoder_hidden_states_1 (`torch.Tensor`, *optional*): A second set of encoder hidden states with shape `(batch, sequence_length_2, feature_dim_2)`. Can be used to condition the model on a different set of embeddings to `encoder_hidden_states`. encoder_attention_mask_1 (`torch.Tensor`, *optional*): A cross-attention mask of shape `(batch, sequence_length_2)` is applied to `encoder_hidden_states_1`. If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. Returns: [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. 
default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) if encoder_attention_mask_1 is not None: encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0 encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1) # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) # 2. pre-process sample = self.conv_in(sample) # 3. 
down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) down_block_res_samples += res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1, ) # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, encoder_hidden_states_1=encoder_hidden_states_1, encoder_attention_mask_1=encoder_attention_mask_1, ) else: sample = upsample_block( hidden_states=sample, temb=emb, 
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    """
    Instantiate the down block identified by `down_block_type`.

    A leading `"UNetRes"` in the name is ignored. `"DownBlock2D"` receives only the resnet/downsample
    arguments; `"CrossAttnDownBlock2D"` additionally receives the attention arguments
    (`transformer_layers_per_block`, `cross_attention_dim`, `num_attention_heads`, `use_linear_projection`,
    `only_cross_attention`, `upcast_attention`).

    Returns:
        The constructed `DownBlock2D` or `CrossAttnDownBlock2D`.

    Raises:
        ValueError: If the block name is unknown, or if `"CrossAttnDownBlock2D"` is requested while
            `cross_attention_dim` is `None`.
    """
    block_name = down_block_type
    if block_name.startswith("UNetRes"):
        block_name = block_name[7:]

    # arguments understood by both block flavours
    shared_kwargs = {
        "num_layers": num_layers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "temb_channels": temb_channels,
        "add_downsample": add_downsample,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "resnet_groups": resnet_groups,
        "downsample_padding": downsample_padding,
        "resnet_time_scale_shift": resnet_time_scale_shift,
    }

    if block_name == "DownBlock2D":
        return DownBlock2D(**shared_kwargs)

    if block_name != "CrossAttnDownBlock2D":
        raise ValueError(f"{block_name} does not exist.")

    if cross_attention_dim is None:
        raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")

    return CrossAttnDownBlock2D(
        transformer_layers_per_block=transformer_layers_per_block,
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=num_attention_heads,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention,
        upcast_attention=upcast_attention,
        **shared_kwargs,
    )
class CrossAttnDownBlock2D(nn.Module):
    """
    UNet down block that follows every resnet with one transformer per entry of `cross_attention_dim`.

    `cross_attention_dim` may be a single int or a list/tuple of up to four entries. For each of the
    `num_layers` resnets, `len(cross_attention_dim)` [`Transformer2DModel`] layers are created. In `forward`,
    the transformers at positions 0 and 1 are conditioned on `encoder_hidden_states`, those at positions 2 and
    3 on `encoder_hidden_states_1`. A `None` entry creates a transformer with `double_self_attention=True` that
    receives no encoder input.

    Args:
        in_channels (`int`): Input channels of the first resnet.
        out_channels (`int`): Output channels of every resnet; also the transformer and downsampler width.
        temb_channels (`int`): Forwarded to each resnet as `temb_channels`.
        dropout (`float`): Forwarded to each resnet.
        num_layers (`int`): Number of resnets.
        transformer_layers_per_block (`int`): `num_layers` argument of each transformer.
        resnet_eps, resnet_time_scale_shift, resnet_act_fn, resnet_groups, resnet_pre_norm,
        output_scale_factor: Forwarded to each resnet (`resnet_groups` is also the transformers'
            `norm_num_groups`).
        num_attention_heads (`int`): Heads per transformer; head dim is `out_channels // num_attention_heads`.
        cross_attention_dim (`int` or `list`/`tuple`): Cross-attention width(s), see above.
        downsample_padding (`int`): Padding of the optional `Downsample2D`.
        add_downsample (`bool`): Whether to append a `Downsample2D` layer.
        use_linear_projection, only_cross_attention, upcast_attention: Forwarded to each transformer.

    Raises:
        ValueError: If more than four cross-attention dims are given.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,)
        if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
            raise ValueError(
                "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
                f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
            )
        self.cross_attention_dim = cross_attention_dim

        for i in range(num_layers):
            # only the first resnet sees the block's input width
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            for dim in cross_attention_dim:
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                        double_self_attention=dim is None,
                    )
                )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    @staticmethod
    def _select_encoder_inputs(
        idx, cross_attention_dim, primary_states, primary_mask, secondary_states, secondary_mask
    ):
        """Return the `(encoder_hidden_states, encoder_attention_mask)` pair for the transformer at `idx`."""
        if cross_attention_dim is None:
            return None, None
        if idx <= 1:
            return primary_states, primary_mask
        return secondary_states, secondary_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states_1: Optional[torch.Tensor] = None,
        encoder_attention_mask_1: Optional[torch.Tensor] = None,
    ):
        """
        Run the resnet/transformer stack and the optional downsampler.

        When `encoder_hidden_states_1` is `None`, the second conditioning falls back to
        `encoder_hidden_states` **together with** `encoder_attention_mask`.

        Returns:
            `tuple`: `(hidden_states, output_states)` where `output_states` holds the activation after every
            resnet+transformer group and, if present, after the downsampler.
        """
        output_states = ()
        num_layers = len(self.resnets)
        num_attention_per_layer = len(self.attentions) // num_layers

        # Fall back on states and mask as a pair. Testing the states *after* overwriting them
        # (as was done previously) left the fallback states without their mask.
        if encoder_hidden_states_1 is None:
            encoder_hidden_states_1 = encoder_hidden_states
            encoder_attention_mask_1 = encoder_attention_mask

        use_checkpointing = self.training and self.gradient_checkpointing
        if use_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        for i in range(num_layers):
            resnet = self.resnets[i]
            if use_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                attention = self.attentions[i * num_attention_per_layer + idx]
                forward_encoder_hidden_states, forward_encoder_attention_mask = self._select_encoder_inputs(
                    idx,
                    cross_attention_dim,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    encoder_hidden_states_1,
                    encoder_attention_mask_1,
                )
                if use_checkpointing:
                    # NOTE(review): these are positional; their order must line up with
                    # Transformer2DModel.forward — confirm against that signature.
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attention, return_dict=False),
                        hidden_states,
                        forward_encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        forward_encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    # NOTE(review): `cross_attention_kwargs` is only forwarded on the checkpointed path.
                    hidden_states = attention(
                        hidden_states,
                        attention_mask=attention_mask,
                        encoder_hidden_states=forward_encoder_hidden_states,
                        encoder_attention_mask=forward_encoder_attention_mask,
                        return_dict=False,
                    )[0]

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states
Ensure that the length of cross-attention " f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}" ) self.cross_attention_dim = cross_attention_dim # there is always at least one resnet resnets = [ ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ] attentions = [] for i in range(num_layers): for j in range(len(cross_attention_dim)): attentions.append( Transformer2DModel( num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim[j], norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, double_self_attention=True if cross_attention_dim[j] is None else False, ) ) resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states_1: Optional[torch.Tensor] = None, encoder_attention_mask_1: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.resnets[0](hidden_states, temb) num_attention_per_layer = len(self.attentions) // 
(len(self.resnets) - 1) encoder_hidden_states_1 = ( encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states ) encoder_attention_mask_1 = ( encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask ) for i in range(len(self.resnets[1:])): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} for idx, cross_attention_dim in enumerate(self.cross_attention_dim): if cross_attention_dim is not None and idx <= 1: forward_encoder_hidden_states = encoder_hidden_states forward_encoder_attention_mask = encoder_attention_mask elif cross_attention_dim is not None and idx > 1: forward_encoder_hidden_states = encoder_hidden_states_1 forward_encoder_attention_mask = encoder_attention_mask_1 else: forward_encoder_hidden_states = None forward_encoder_attention_mask = None hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False), hidden_states, forward_encoder_hidden_states, None, # timestep None, # class_labels cross_attention_kwargs, attention_mask, forward_encoder_attention_mask, **ckpt_kwargs, )[0] hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(self.resnets[i + 1]), hidden_states, temb, **ckpt_kwargs, ) else: for idx, cross_attention_dim in enumerate(self.cross_attention_dim): if cross_attention_dim is not None and idx <= 1: forward_encoder_hidden_states = encoder_hidden_states forward_encoder_attention_mask = encoder_attention_mask elif cross_attention_dim is not None and idx > 1: forward_encoder_hidden_states = encoder_hidden_states_1 forward_encoder_attention_mask = encoder_attention_mask_1 else: 
class CrossAttnUpBlock2D(nn.Module):
    """
    UNet up block that follows every resnet with one transformer per entry of `cross_attention_dim`.

    Before each resnet the last tensor of `res_hidden_states_tuple` is concatenated to the hidden states along
    dim 1. `cross_attention_dim` may be a single int or a list/tuple of up to four entries; in `forward`, the
    transformers at positions 0 and 1 are conditioned on `encoder_hidden_states`, those at positions 2 and 3
    on `encoder_hidden_states_1`. A `None` entry creates a transformer with `double_self_attention=True` that
    receives no encoder input.

    Args:
        in_channels (`int`): Channel count of the skip tensor consumed by the last resnet.
        out_channels (`int`): Output channels of every resnet; also the transformer and upsampler width, and
            the skip-tensor channel count for all but the last resnet.
        prev_output_channel (`int`): Channel count of the incoming hidden states (first resnet only).
        temb_channels (`int`): Forwarded to each resnet as `temb_channels`.
        dropout (`float`): Forwarded to each resnet.
        num_layers (`int`): Number of resnets.
        transformer_layers_per_block (`int`): `num_layers` argument of each transformer.
        resnet_eps, resnet_time_scale_shift, resnet_act_fn, resnet_groups, resnet_pre_norm,
        output_scale_factor: Forwarded to each resnet (`resnet_groups` is also the transformers'
            `norm_num_groups`).
        num_attention_heads (`int`): Heads per transformer; head dim is `out_channels // num_attention_heads`.
        cross_attention_dim (`int` or `list`/`tuple`): Cross-attention width(s), see above.
        add_upsample (`bool`): Whether to append an `Upsample2D` layer.
        use_linear_projection, only_cross_attention, upcast_attention: Forwarded to each transformer.

    Raises:
        ValueError: If more than four cross-attention dims are given.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,)
        if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
            raise ValueError(
                "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
                f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
            )
        self.cross_attention_dim = cross_attention_dim

        for i in range(num_layers):
            # channels contributed by the skip connection and by the running hidden states
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            for dim in cross_attention_dim:
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                        double_self_attention=dim is None,
                    )
                )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    @staticmethod
    def _select_encoder_inputs(
        idx, cross_attention_dim, primary_states, primary_mask, secondary_states, secondary_mask
    ):
        """Return the `(encoder_hidden_states, encoder_attention_mask)` pair for the transformer at `idx`."""
        if cross_attention_dim is None:
            return None, None
        if idx <= 1:
            return primary_states, primary_mask
        return secondary_states, secondary_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states_1: Optional[torch.Tensor] = None,
        encoder_attention_mask_1: Optional[torch.Tensor] = None,
    ):
        """
        Consume the skip tensors (last one first), run the resnet/transformer stack and the optional upsampler.

        When `encoder_hidden_states_1` is `None`, the second conditioning falls back to
        `encoder_hidden_states` **together with** `encoder_attention_mask`.

        Returns:
            `torch.Tensor`: The resulting hidden states.
        """
        num_layers = len(self.resnets)
        num_attention_per_layer = len(self.attentions) // num_layers

        # Fall back on states and mask as a pair. Testing the states *after* overwriting them
        # (as was done previously) left the fallback states without their mask.
        if encoder_hidden_states_1 is None:
            encoder_hidden_states_1 = encoder_hidden_states
            encoder_attention_mask_1 = encoder_attention_mask

        use_checkpointing = self.training and self.gradient_checkpointing
        if use_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        for i in range(num_layers):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            resnet = self.resnets[i]
            if use_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                attention = self.attentions[i * num_attention_per_layer + idx]
                forward_encoder_hidden_states, forward_encoder_attention_mask = self._select_encoder_inputs(
                    idx,
                    cross_attention_dim,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    encoder_hidden_states_1,
                    encoder_attention_mask_1,
                )
                if use_checkpointing:
                    # NOTE(review): these are positional; their order must line up with
                    # Transformer2DModel.forward — confirm against that signature.
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attention, return_dict=False),
                        hidden_states,
                        forward_encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        forward_encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    # NOTE(review): `cross_attention_kwargs` is only forwarded on the checkpointed path.
                    hidden_states = attention(
                        hidden_states,
                        attention_mask=attention_mask,
                        encoder_hidden_states=forward_encoder_hidden_states,
                        encoder_attention_mask=forward_encoder_attention_mask,
                        return_dict=False,
                    )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch from transformers import ( ClapFeatureExtractor, ClapModel, GPT2Model, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan, T5EncoderModel, T5Tokenizer, T5TokenizerFast, VitsModel, VitsTokenizer, ) from ...models import AutoencoderKL from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, is_accelerate_version, is_librosa_available, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel if is_librosa_available(): import librosa logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import scipy >>> import torch >>> from diffusers import AudioLDM2Pipeline >>> repo_id = "cvssp/audioldm2" >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") >>> # define the prompts >>> prompt = "The sound of a hammer hitting a wooden surface." >>> negative_prompt = "Low quality." >>> # set the seed for generator >>> generator = torch.Generator("cuda").manual_seed(0) >>> # run the generation >>> audio = pipe( ... prompt, ... negative_prompt=negative_prompt, ... num_inference_steps=200, ... audio_length_in_s=10.0, ... num_waveforms_per_prompt=3, ... generator=generator, ... 
def prepare_inputs_for_generation(
    inputs_embeds,
    attention_mask=None,
    past_key_values=None,
    **kwargs,
):
    """
    Build the keyword arguments for a single forward pass of the language model.

    Args:
        inputs_embeds: Embeddings accumulated so far; indexed as `[:, -1:]` when a cache is present.
        attention_mask: Passed through unchanged.
        past_key_values: Cached state from earlier steps, or `None` on the first step.
        **kwargs: Only `use_cache` is read; every other entry is ignored.

    Returns:
        `dict` with the keys `inputs_embeds`, `attention_mask`, `past_key_values` and `use_cache`
        (`use_cache` is `None` when it was not supplied).
    """
    # With a cache available the earlier positions are already accounted for, so only the last one is forwarded.
    has_cache = past_key_values is not None
    step_embeds = inputs_embeds[:, -1:] if has_cache else inputs_embeds

    model_inputs = dict(
        inputs_embeds=step_embeds,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
    )
    model_inputs["use_cache"] = kwargs.get("use_cache")
    return model_inputs
AudioLDM2 uses the joint audio-text embedding model [CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection), specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to rank generated waveforms against the text prompt by computing similarity scores. text_encoder_2 ([`~transformers.T5EncoderModel`, `~transformers.VitsModel`]): Second frozen text-encoder. AudioLDM2 uses the encoder of [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant. For text-to-speech (TTS), the second frozen text-encoder is instead the encoder of [Vits](https://huggingface.co/docs/transformers/model_doc/vits#transformers.VitsModel). projection_model ([`AudioLDM2ProjectionModel`]): A trained model used to linearly project the hidden-states from the first and second text encoder models and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are concatenated to give the input to the language model. It also provides a learned position embedding for the Vits hidden-states. language_model ([`~transformers.GPT2Model`]): An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected outputs from the two text encoders. tokenizer ([`~transformers.RobertaTokenizer`]): Tokenizer to tokenize text for the first frozen text-encoder. tokenizer_2 ([`~transformers.T5Tokenizer`, `~transformers.VitsTokenizer`]): Tokenizer to tokenize text for the second frozen text-encoder. feature_extractor ([`~transformers.ClapFeatureExtractor`]): Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded audio latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. vocoder ([`~transformers.SpeechT5HifiGan`]): Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform. """ def __init__( self, vae: AutoencoderKL, text_encoder: ClapModel, text_encoder_2: Union[T5EncoderModel, VitsModel], projection_model: AudioLDM2ProjectionModel, language_model: GPT2Model, tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast], tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer], feature_extractor: ClapFeatureExtractor, unet: AudioLDM2UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, vocoder: SpeechT5HifiGan, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, projection_model=projection_model, language_model=language_model, tokenizer=tokenizer, tokenizer_2=tokenizer_2, feature_extractor=feature_extractor, unet=unet, scheduler=scheduler, vocoder=vocoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to computing decoding in one step. 
""" self.vae.disable_slicing() def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) model_sequence = [ self.text_encoder.text_model, self.text_encoder.text_projection, self.text_encoder_2, self.projection_model, self.language_model, self.unet, self.vae, self.vocoder, self.text_encoder, ] hook = None for cpu_offloaded_model in model_sequence: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook def generate_language_model( self, inputs_embeds: torch.Tensor = None, max_new_tokens: int = 8, **model_kwargs, ): """ Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs. Parameters: inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): The sequence used as a prompt for the generation. max_new_tokens (`int`): Number of new tokens to generate. model_kwargs (`Dict[str, Any]`, *optional*): Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward` function of the model. 
Return: `inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): The sequence of generated hidden-states. """ max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs) for _ in range(max_new_tokens): # prepare model inputs model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs) # forward pass to get next hidden states output = self.language_model(**model_inputs, return_dict=True) next_hidden_states = output.last_hidden_state # Update the model input inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1) # Update generated hidden states, model inputs, and length for next step model_kwargs = self.language_model._update_model_kwargs_for_generation(output, model_kwargs) return inputs_embeds[:, -max_new_tokens:, :] def encode_prompt( self, prompt, device, num_waveforms_per_prompt, do_classifier_free_guidance, transcription=None, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, generated_prompt_embeds: Optional[torch.Tensor] = None, negative_generated_prompt_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None, negative_attention_mask: Optional[torch.LongTensor] = None, max_new_tokens: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded transcription (`str` or `List[str]`): transcription of text to speech device (`torch.device`): torch device num_waveforms_per_prompt (`int`): number of waveforms that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the audio generation. 
If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-computed text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-computed negative text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from `negative_prompt` input argument. generated_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_generated_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_generated_prompt_embeds will be generated from `negative_prompt` input argument. attention_mask (`torch.LongTensor`, *optional*): Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will be computed from `prompt` input argument. negative_attention_mask (`torch.LongTensor`, *optional*): Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention mask will be computed from `negative_prompt` input argument. max_new_tokens (`int`, *optional*, defaults to None): The number of new tokens to generate with the GPT2 language model. Returns: prompt_embeds (`torch.Tensor`): Text embeddings from the Flan T5 model. attention_mask (`torch.LongTensor`): Attention mask to be applied to the `prompt_embeds`.
generated_prompt_embeds (`torch.Tensor`): Text embeddings generated from the GPT2 langauge model. Example: ```python >>> import scipy >>> import torch >>> from diffusers import AudioLDM2Pipeline >>> repo_id = "cvssp/audioldm2" >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") >>> # Get text embedding vectors >>> prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt( ... prompt="Techno music with a strong, upbeat tempo and high melodic riffs", ... device="cuda", ... do_classifier_free_guidance=True, ... ) >>> # Pass text embeddings to pipeline for text-conditional audio generation >>> audio = pipe( ... prompt_embeds=prompt_embeds, ... attention_mask=attention_mask, ... generated_prompt_embeds=generated_prompt_embeds, ... num_inference_steps=200, ... audio_length_in_s=10.0, ... ).audios[0] >>> # save generated audio sample >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio) ```""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] is_vits_text_encoder = isinstance(self.text_encoder_2, VitsModel) if is_vits_text_encoder: text_encoders = [self.text_encoder, self.text_encoder_2.text_encoder] else: text_encoders = [self.text_encoder, self.text_encoder_2] if prompt_embeds is None: prompt_embeds_list = [] attention_mask_list = [] for tokenizer, text_encoder in zip(tokenizers, text_encoders): use_prompt = isinstance( tokenizer, (RobertaTokenizer, RobertaTokenizerFast, T5Tokenizer, T5TokenizerFast) ) text_inputs = tokenizer( prompt if use_prompt else transcription, padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer)) else True, max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) 
text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( f"The following part of your input was truncated because {text_encoder.config.model_type} can " f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids.to(device) attention_mask = attention_mask.to(device) if text_encoder.config.model_type == "clap": prompt_embeds = text_encoder.get_text_features( text_input_ids, attention_mask=attention_mask, ) # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size) prompt_embeds = prompt_embeds[:, None, :] # make sure that we attend to this single hidden-state attention_mask = attention_mask.new_ones((batch_size, 1)) elif is_vits_text_encoder: # Add end_token_id and attention mask in the end of sequence phonemes for text_input_id, text_attention_mask in zip(text_input_ids, attention_mask): for idx, phoneme_id in enumerate(text_input_id): if phoneme_id == 0: text_input_id[idx] = 182 text_attention_mask[idx] = 1 break prompt_embeds = text_encoder( text_input_ids, attention_mask=attention_mask, padding_mask=attention_mask.unsqueeze(-1) ) prompt_embeds = prompt_embeds[0] else: prompt_embeds = text_encoder( text_input_ids, attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds_list.append(prompt_embeds) attention_mask_list.append(attention_mask) projection_output = self.projection_model( hidden_states=prompt_embeds_list[0], hidden_states_1=prompt_embeds_list[1], attention_mask=attention_mask_list[0], attention_mask_1=attention_mask_list[1], ) projected_prompt_embeds = projection_output.hidden_states projected_attention_mask = 
projection_output.attention_mask generated_prompt_embeds = self.generate_language_model( projected_prompt_embeds, attention_mask=projected_attention_mask, max_new_tokens=max_new_tokens, ) prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) attention_mask = ( attention_mask.to(device=device) if attention_mask is not None else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device) ) generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.language_model.dtype, device=device) bs_embed, seq_len, hidden_size = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size) # duplicate attention mask for each generation per prompt attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt) attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len) bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape # duplicate generated embeddings for each generation per prompt, using mps friendly method generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1) generated_prompt_embeds = generated_prompt_embeds.view( bs_embed * num_waveforms_per_prompt, seq_len, hidden_size ) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt negative_prompt_embeds_list = [] negative_attention_mask_list = [] max_length = prompt_embeds.shape[1] for tokenizer, text_encoder in zip(tokenizers, text_encoders): uncond_input = tokenizer( uncond_tokens, padding="max_length", max_length=tokenizer.model_max_length if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer)) else max_length, truncation=True, return_tensors="pt", ) uncond_input_ids = uncond_input.input_ids.to(device) negative_attention_mask = uncond_input.attention_mask.to(device) if text_encoder.config.model_type == "clap": negative_prompt_embeds = text_encoder.get_text_features( uncond_input_ids, attention_mask=negative_attention_mask, ) # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size) negative_prompt_embeds = negative_prompt_embeds[:, None, :] # make sure that we attend to this single hidden-state negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1)) elif is_vits_text_encoder: negative_prompt_embeds = torch.zeros( batch_size, tokenizer.model_max_length, text_encoder.config.hidden_size, ).to(dtype=self.text_encoder_2.dtype, device=device) negative_attention_mask = torch.zeros(batch_size, tokenizer.model_max_length).to( dtype=self.text_encoder_2.dtype, device=device ) else: negative_prompt_embeds = text_encoder( uncond_input_ids, attention_mask=negative_attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_attention_mask_list.append(negative_attention_mask) projection_output = self.projection_model( 
hidden_states=negative_prompt_embeds_list[0], hidden_states_1=negative_prompt_embeds_list[1], attention_mask=negative_attention_mask_list[0], attention_mask_1=negative_attention_mask_list[1], ) negative_projected_prompt_embeds = projection_output.hidden_states negative_projected_attention_mask = projection_output.attention_mask negative_generated_prompt_embeds = self.generate_language_model( negative_projected_prompt_embeds, attention_mask=negative_projected_attention_mask, max_new_tokens=max_new_tokens, ) if do_classifier_free_guidance: seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) negative_attention_mask = ( negative_attention_mask.to(device=device) if negative_attention_mask is not None else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device) ) negative_generated_prompt_embeds = negative_generated_prompt_embeds.to( dtype=self.language_model.dtype, device=device ) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1) # duplicate unconditional attention mask for each generation per prompt negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt) negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len) # duplicate unconditional generated embeddings for each generation per prompt seq_len = negative_generated_prompt_embeds.shape[1] negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1) negative_generated_prompt_embeds = negative_generated_prompt_embeds.view( batch_size * num_waveforms_per_prompt, seq_len, -1 ) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) attention_mask = torch.cat([negative_attention_mask, attention_mask]) generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds]) return prompt_embeds, attention_mask, generated_prompt_embeds # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform def mel_spectrogram_to_waveform(self, mel_spectrogram): if mel_spectrogram.dim() == 4: mel_spectrogram = mel_spectrogram.squeeze(1) waveform = self.vocoder(mel_spectrogram) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 waveform = waveform.cpu().float() return waveform def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype): if not is_librosa_available(): logger.info( "Automatic scoring of the generated audio waveforms against the input prompt text requires the " "`librosa` package to resample the generated waveforms. Returning the audios in the order they were " "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`." 
) return audio inputs = self.tokenizer(text, return_tensors="pt", padding=True) resampled_audio = librosa.resample( audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate ) inputs["input_features"] = self.feature_extractor( list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate ).input_features.type(dtype) inputs = inputs.to(device) # compute the audio-text similarity score using the CLAP model logits_per_text = self.text_encoder(**inputs).logits_per_text # sort by the highest matching generations per prompt indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt] audio = torch.index_select(audio, 0, indices.reshape(-1).cpu()) return audio # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, audio_length_in_s, vocoder_upsample_factor, callback_steps, transcription=None, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, generated_prompt_embeds=None, negative_generated_prompt_embeds=None, attention_mask=None, negative_attention_mask=None, ): min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor if audio_length_in_s < min_audio_length_in_s: raise ValueError( f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but " f"is {audio_length_in_s}." ) if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0: raise ValueError( f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the " f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of " f"{self.vae_scale_factor}." ) if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." 
) elif prompt is None and (prompt_embeds is None or generated_prompt_embeds is None): raise ValueError( "Provide either `prompt`, or `prompt_embeds` and `generated_prompt_embeds`. Cannot leave " "`prompt` undefined without specifying both `prompt_embeds` and `generated_prompt_embeds`." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) elif negative_prompt_embeds is not None and negative_generated_prompt_embeds is None: raise ValueError( "Cannot forward `negative_prompt_embeds` without `negative_generated_prompt_embeds`. Ensure that" "both arguments are specified" ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]: raise ValueError( "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:" f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}" ) if transcription is None: if self.text_encoder_2.config.model_type == "vits": raise ValueError("Cannot forward without transcription. 
Please make sure to" " have transcription") elif transcription is not None and ( not isinstance(transcription, str) and not isinstance(transcription, list) ): raise ValueError(f"`transcription` has to be of type `str` or `list` but is {type(transcription)}") if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None: if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape: raise ValueError( "`generated_prompt_embeds` and `negative_generated_prompt_embeds` must have the same shape when " f"passed directly, but got: `generated_prompt_embeds` {generated_prompt_embeds.shape} != " f"`negative_generated_prompt_embeds` {negative_generated_prompt_embeds.shape}." ) if ( negative_attention_mask is not None and negative_attention_mask.shape != negative_prompt_embeds.shape[:2] ): raise ValueError( "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:" f"`attention_mask: {negative_attention_mask.shape} != `prompt_embeds` {negative_prompt_embeds.shape}" ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None): shape = ( batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(self.vocoder.config.model_in_dim) // self.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, transcription: Union[str, List[str]] = None, audio_length_in_s: Optional[float] = None, num_inference_steps: int = 200, guidance_scale: float = 3.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_waveforms_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, generated_prompt_embeds: Optional[torch.Tensor] = None, negative_generated_prompt_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None, negative_attention_mask: Optional[torch.LongTensor] = None, max_new_tokens: Optional[int] = None, return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: Optional[int] = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, output_type: Optional[str] = "np", ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. transcription (`str` or `List[str]`, *optional*):\ The transcript for text to speech. audio_length_in_s (`int`, *optional*, defaults to 10.24): The length of the generated audio sample in seconds. num_inference_steps (`int`, *optional*, defaults to 200): The number of denoising steps. More denoising steps usually lead to a higher quality audio at the expense of slower inference. 
guidance_scale (`float`, *optional*, defaults to 3.5): A higher guidance scale value encourages the model to generate audio that is closely linked to the text `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in audio generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_waveforms_per_prompt (`int`, *optional*, defaults to 1): The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, then automatic scoring is performed between the generated outputs and the text prompt. This scoring ranks the generated waveforms based on their cosine similarity with the text input in the joint text-audio embedding space. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for spectrogram generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. generated_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings from the GPT2 langauge model. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_generated_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from `negative_prompt` input argument. attention_mask (`torch.LongTensor`, *optional*): Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will be computed from `prompt` input argument. negative_attention_mask (`torch.LongTensor`, *optional*): Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention mask will be computed from `negative_prompt` input argument. max_new_tokens (`int`, *optional*, defaults to None): Number of new tokens to generate with the GPT2 language model. If not provided, number of tokens will be taken from the config of the model. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). output_type (`str`, *optional*, defaults to `"np"`): The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion model (LDM) output. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated audio. """ # 0. Convert audio input length from seconds to spectrogram height vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate if audio_length_in_s is None: audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor height = int(audio_length_in_s / vocoder_upsample_factor) original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate) if height % self.vae_scale_factor != 0: height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor logger.info( f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} " f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the " f"denoising process." ) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, audio_length_in_s, vocoder_upsample_factor, callback_steps, transcription, negative_prompt, prompt_embeds, negative_prompt_embeds, generated_prompt_embeds, negative_generated_prompt_embeds, attention_mask, negative_attention_mask, ) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds, attention_mask, generated_prompt_embeds = self.encode_prompt( prompt, device, num_waveforms_per_prompt, do_classifier_free_guidance, transcription, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, generated_prompt_embeds=generated_prompt_embeds, negative_generated_prompt_embeds=negative_generated_prompt_embeds, attention_mask=attention_mask, negative_attention_mask=negative_attention_mask, max_new_tokens=max_new_tokens, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_waveforms_per_prompt, num_channels_latents, height, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=generated_prompt_embeds, encoder_hidden_states_1=prompt_embeds, encoder_attention_mask_1=attention_mask, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) self.maybe_free_model_hooks() # 8. Post-processing if not output_type == "latent": latents = 1 / self.vae.config.scaling_factor * latents mel_spectrogram = self.vae.decode(latents).sample else: return AudioPipelineOutput(audios=latents) audio = self.mel_spectrogram_to_waveform(mel_spectrogram) audio = audio[:, :original_waveform_length] # 9. 
# Package initialiser for `diffusers.pipelines.aura_flow`.
#
# It registers `AuraFlowPipeline` for lazy import when both torch and transformers are
# available, and otherwise exposes the placeholder objects collected from
# `dummy_torch_and_transformers_objects` under the same names.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Placeholder objects to attach to the module when the optional dependencies are missing.
_dummy_objects = {}
# Maps submodule name -> list of public names; handed to `_LazyModule` below.
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_aura_flow"] = ["AuraFlowPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: import the real class (or the dummies) immediately so the names exist
    # as ordinary module attributes.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_aura_flow import AuraFlowPipeline
else:
    import sys

    # Lazy path: replace this module's entry in `sys.modules` with a `_LazyModule`
    # built from `_import_structure`.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach the dummy placeholders (if any were collected) to the replacement module.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and hand back the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or the scheduler's own
    spacing driven by `num_inference_steps`. Extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps; only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as the `device` argument.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration, and the number of inference steps
        (the length of that schedule when a custom schedule was supplied, otherwise `num_inference_steps` as passed).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler pick its own spacing.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: the scheduler must explicitly accept the keyword we are about to pass.
    accepted_params = inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
""" _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" def __init__( self, tokenizer: T5Tokenizer, text_encoder: UMT5EncoderModel, vae: AutoencoderKL, transformer: AuraFlowTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def check_inputs( self, prompt, height, width, negative_prompt, prompt_embeds=None, negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    negative_prompt: Union[str, List[str]] = None,
    do_classifier_free_guidance: bool = True,
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 256,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Embeddings and attention masks are both duplicated `num_images_per_prompt` times in the same (interleaved)
    order, so row `i` of a returned mask always belongs to row `i` of the matching embeddings.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
            instead. Only used when `do_classifier_free_guidance` is `True`; an empty string is used when it is
            `None`.
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            whether to use classifier free guidance or not
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            number of images that should be generated per prompt
        device: (`torch.device`, *optional*):
            torch device to place the resulting embeddings on. Defaults to `self._execution_device`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Pre-generated attention mask for text embeddings. Required when `prompt_embeds` is passed.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Pre-generated attention mask for negative text embeddings. Required when `negative_prompt_embeds` is
            passed.
        max_sequence_length (`int`, defaults to 256):
            Maximum sequence length to use for the prompt.

    Returns:
        `tuple`: `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask)`.
        The two negative entries are `None` when `do_classifier_free_guidance` is `False`. Masks are returned
        flattened to two dimensions with one row per generated image.
    """
    if device is None:
        device = self._execution_device

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    max_length = max_sequence_length
    if prompt_embeds is None:
        text_inputs = self.tokenizer(
            prompt,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        text_input_ids = text_inputs["input_ids"]
        # Tokenize again without truncation purely to detect (and report) what was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because T5 can only handle sequences up to"
                f" {max_length} tokens: {removed_text}"
            )

        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
        prompt_embeds = self.text_encoder(**text_inputs)[0]
        # Broadcast the token mask over the hidden dimension and zero out the padded positions.
        prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
        prompt_embeds = prompt_embeds * prompt_attention_mask

    if self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    elif self.transformer is not None:
        dtype = self.transformer.dtype
    else:
        dtype = None

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape

    def _duplicate_mask_per_prompt(mask):
        # Same repeat-then-view scheme as the embeddings below, so the duplicated rows come out in
        # the same interleaved order (m0, m0, m1, m1, ...). The previous `repeat(n, 1)` produced
        # (m0, m1, m0, m1, ...), which no longer matched the embeddings row for row.
        flat_mask = mask.reshape(bs_embed, -1)
        flat_mask = flat_mask.repeat(1, num_images_per_prompt)
        return flat_mask.view(bs_embed * num_images_per_prompt, -1)

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    prompt_attention_mask = _duplicate_mask_per_prompt(prompt_attention_mask)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        uncond_input = {k: v.to(device) for k, v in uncond_input.items()}
        negative_prompt_embeds = self.text_encoder(**uncond_input)[0]
        negative_prompt_attention_mask = (
            uncond_input["attention_mask"].unsqueeze(-1).expand(negative_prompt_embeds.shape)
        )
        negative_prompt_embeds = negative_prompt_embeds * negative_prompt_attention_mask

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        negative_prompt_attention_mask = _duplicate_mask_per_prompt(negative_prompt_attention_mask)
    else:
        negative_prompt_embeds = None
        negative_prompt_attention_mask = None

    return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    height: Optional[int] = 1024,
    width: Optional[int] = 1024,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 256,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            not greater than `1`).
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. Cannot be combined with
            `timesteps`.
        guidance_scale (`float`, *optional*, defaults to 3.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        height (`int`, *optional*, defaults to 1024):
            The height in pixels of the generated image. A falsy value (e.g. `None`) falls back to
            `self.transformer.config.sample_size * self.vae_scale_factor`.
        width (`int`, *optional*, defaults to 1024):
            The width in pixels of the generated image. A falsy value (e.g. `None`) falls back to
            `self.transformer.config.sample_size * self.vae_scale_factor`.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Pre-generated attention mask for text embeddings.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Pre-generated attention mask for negative text embeddings.
        max_sequence_length (`int` defaults to 256):
            Maximum sequence length to use with the `prompt`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"latent"` returns the final latents without decoding them;
            any other value is forwarded to the image processor's `postprocess` step.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

    Examples:

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
            returned where the first element is the generated images.
    """
    # 1. Check inputs. Raise error if not correct
    height = height or self.transformer.config.sample_size * self.vae_scale_factor
    width = width or self.transformer.config.sample_size * self.vae_scale_factor

    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
    )

    # 2. Determine batch size.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        max_sequence_length=max_sequence_length,
    )
    if do_classifier_free_guidance:
        # Unconditional embeddings go first so that `noise_pred.chunk(2)` below yields (uncond, text).
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

    # 4. Prepare timesteps
    # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )

    # 5. Prepare latents.
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        latent_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0])
            timestep = timestep.to(latents.device, dtype=latents.dtype)

            # predict noise model_output
            noise_pred = self.transformer(
                latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # advance the progress bar on the last timestep and after each completed scheduler step
            # past the warmup (no user callback is invoked here)
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    if output_type == "latent":
        # Caller asked for raw latents: skip VAE decoding and post-processing entirely.
        image = latents
    else:
        # make sure the VAE is in float32 mode, as it overflows in float16
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.upcast_vae()
            # match the latents to whatever dtype `post_quant_conv` ended up in after upcasting
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
================================================ # coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict from huggingface_hub.utils import validate_hf_hub_args from ..configuration_utils import ConfigMixin from ..utils import is_sentencepiece_available from .aura_flow import AuraFlowPipeline from .controlnet import ( StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, ) from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline from .flux import FluxPipeline from .hunyuandit import HunyuanDiTPipeline from .kandinsky import ( KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyImg2ImgPipeline, KandinskyInpaintCombinedPipeline, KandinskyInpaintPipeline, KandinskyPipeline, ) from .kandinsky2_2 import ( KandinskyV22CombinedPipeline, KandinskyV22Img2ImgCombinedPipeline, KandinskyV22Img2ImgPipeline, KandinskyV22InpaintCombinedPipeline, KandinskyV22InpaintPipeline, KandinskyV22Pipeline, ) from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline from .pag import ( HunyuanDiTPAGPipeline, PixArtSigmaPAGPipeline, StableDiffusion3PAGPipeline, 
StableDiffusionControlNetPAGPipeline, StableDiffusionPAGPipeline, StableDiffusionXLControlNetPAGPipeline, StableDiffusionXLPAGImg2ImgPipeline, StableDiffusionXLPAGInpaintPipeline, StableDiffusionXLPAGPipeline, ) from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline from .stable_diffusion import ( StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, StableDiffusionPipeline, ) from .stable_diffusion_3 import ( StableDiffusion3Img2ImgPipeline, StableDiffusion3InpaintPipeline, StableDiffusion3Pipeline, ) from .stable_diffusion_xl import ( StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( [ ("stable-diffusion", StableDiffusionPipeline), ("stable-diffusion-xl", StableDiffusionXLPipeline), ("stable-diffusion-3", StableDiffusion3Pipeline), ("stable-diffusion-3-pag", StableDiffusion3PAGPipeline), ("if", IFPipeline), ("hunyuan", HunyuanDiTPipeline), ("hunyuan-pag", HunyuanDiTPAGPipeline), ("kandinsky", KandinskyCombinedPipeline), ("kandinsky22", KandinskyV22CombinedPipeline), ("kandinsky3", Kandinsky3Pipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline), ("wuerstchen", WuerstchenCombinedPipeline), ("cascade", StableCascadeCombinedPipeline), ("lcm", LatentConsistencyModelPipeline), ("pixart-alpha", PixArtAlphaPipeline), ("pixart-sigma", PixArtSigmaPipeline), ("stable-diffusion-pag", StableDiffusionPAGPipeline), ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline), ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline), ("pixart-sigma-pag", PixArtSigmaPAGPipeline), ("auraflow", AuraFlowPipeline), 
("flux", FluxPipeline), ] ) AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict( [ ("stable-diffusion", StableDiffusionImg2ImgPipeline), ("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline), ("stable-diffusion-3", StableDiffusion3Img2ImgPipeline), ("if", IFImg2ImgPipeline), ("kandinsky", KandinskyImg2ImgCombinedPipeline), ("kandinsky22", KandinskyV22Img2ImgCombinedPipeline), ("kandinsky3", Kandinsky3Img2ImgPipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline), ] ) AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict( [ ("stable-diffusion", StableDiffusionInpaintPipeline), ("stable-diffusion-xl", StableDiffusionXLInpaintPipeline), ("stable-diffusion-3", StableDiffusion3InpaintPipeline), ("if", IFInpaintingPipeline), ("kandinsky", KandinskyInpaintCombinedPipeline), ("kandinsky22", KandinskyV22InpaintCombinedPipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ] ) _AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict( [ ("kandinsky", KandinskyPipeline), ("kandinsky22", KandinskyV22Pipeline), ("wuerstchen", WuerstchenDecoderPipeline), ("cascade", StableCascadeDecoderPipeline), ] ) _AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict( [ ("kandinsky", KandinskyImg2ImgPipeline), ("kandinsky22", KandinskyV22Img2ImgPipeline), ] ) _AUTO_INPAINT_DECODER_PIPELINES_MAPPING = OrderedDict( [ ("kandinsky", KandinskyInpaintPipeline), ("kandinsky22", KandinskyV22InpaintPipeline), ] ) if is_sentencepiece_available(): from .kolors import KolorsPipeline from .pag import KolorsPAGPipeline AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline 
def _get_connected_pipeline(pipeline_cls):
    """Resolve the combined pipeline class that a decoder pipeline class belongs to.

    Connected pipelines can currently only be loaded from decoder pipelines, such as
    `kandinsky-community/kandinsky-2-2-decoder`.

    Args:
        pipeline_cls: The pipeline class to look up.

    Returns:
        The class registered in the matching task mapping, or `None` when `pipeline_cls` is not one of the
        registered decoder pipelines or no linked class exists.
    """
    # Each private decoder table is paired with the public task table it resolves into.
    decoder_to_task_mappings = (
        (_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING, AUTO_TEXT2IMAGE_PIPELINES_MAPPING),
        (_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING),
        (_AUTO_INPAINT_DECODER_PIPELINES_MAPPING, AUTO_INPAINT_PIPELINES_MAPPING),
    )
    for decoder_mapping, task_mapping in decoder_to_task_mappings:
        if pipeline_cls in decoder_mapping.values():
            # Only the first decoder table containing the class is consulted.
            return _get_task_class(task_mapping, pipeline_cls.__name__, throw_error_if_not_exist=False)
    return None


def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
    """Find the pipeline class in `mapping` that shares a model name with `pipeline_class_name`.

    The class name is first translated to a model name by scanning every table in
    `SUPPORTED_TASKS_MAPPINGS` (first match wins); that model name is then looked up in `mapping`.

    Args:
        mapping: Task mapping (model name -> pipeline class) to resolve into.
        pipeline_class_name: `__name__` of the pipeline class to start from.
        throw_error_if_not_exist: Whether a failed lookup raises instead of returning `None`.

    Returns:
        The linked pipeline class, or `None` if nothing is found and `throw_error_if_not_exist` is `False`.

    Raises:
        ValueError: If nothing is found and `throw_error_if_not_exist` is `True`.
    """
    model_name = next(
        (
            candidate_name
            for task_mapping in SUPPORTED_TASKS_MAPPINGS
            for candidate_name, candidate_cls in task_mapping.items()
            if candidate_cls.__name__ == pipeline_class_name
        ),
        None,
    )

    task_class = None if model_name is None else mapping.get(model_name, None)
    if task_class is not None:
        return task_class

    if throw_error_if_not_exist:
        raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name}")
    return None
class AutoPipelineForText2Image(ConfigMixin):
    r"""
    [`AutoPipelineForText2Image`] is a generic pipeline class that instantiates a text-to-image pipeline class. The
    specific underlying pipeline class is automatically selected from either the
    [`~AutoPipelineForText2Image.from_pretrained`] or [`~AutoPipelineForText2Image.from_pipe`] methods.

    This class cannot be instantiated using `__init__()` (throws an error).

    Class attributes:

        - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
          diffusion pipeline's components.

    """

    config_name = "model_index.json"

    def __init__(self, *args, **kwargs):
        # Direct construction is never valid; point the caller at the two factory methods.
        raise EnvironmentError(
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
            f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
        )

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_or_path, **kwargs):
        r"""
        Instantiates a text-to-image Pytorch diffusion pipeline from pretrained pipeline weight.

        The from_pretrained() method takes care of returning the correct pipeline class instance by:
            1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its
               config object
            2. Find the text-to-image pipeline linked to the pipeline class using pattern matching on pipeline class
               name.

        If a `controlnet` argument is passed, it will instantiate a [`StableDiffusionControlNetPipeline`] object.

        The pipeline is set in evaluation mode (`model.eval()`) by default.

        If you get the error message below, you need to finetune the weights for your downstream task:

        ```
        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
        ```

        Parameters:
            pretrained_model_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
                      hosted on the Hub.
                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                      saved using
                    [`~DiffusionPipeline.save_pretrained`].
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
                dtype is automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            custom_revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
                `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
                custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be defined for each
                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                same device.

                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
                each GPU and the available CPU RAM if unset.
            offload_folder (`str` or `os.PathLike`, *optional*):
                The path to offload weights if device_map contains the value `"disk"`.
            offload_state_dict (`bool`, *optional*):
                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
                when there is some disk offload.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                weights. If set to `False`, safetensors weights are not loaded.
            enable_pag (`bool`, *optional*):
                When truthy, `Pipeline` in the detected class name is replaced by `PAGPipeline` before the
                text-to-image class is looked up. The key is removed from `kwargs` and is not forwarded.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
                class). The overwritten components are passed directly to the pipelines `__init__` method. See example
                below for more information.
            variant (`str`, *optional*):
                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.

        <Tip>

        To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
        `huggingface-cli login`.

        </Tip>

        Examples:

        ```py
        >>> from diffusers import AutoPipelineForText2Image

        >>> pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> image = pipeline(prompt).images[0]
        ```
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)

        # Hub-related options are needed twice: once to fetch the config and once for the real load below.
        load_config_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "token": token,
            "local_files_only": local_files_only,
            "revision": revision,
        }

        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
        orig_class_name = config["_class_name"]

        # The target class is derived purely by rewriting the class name stored in the config.
        if "controlnet" in kwargs:
            orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
        if "enable_pag" in kwargs:
            enable_pag = kwargs.pop("enable_pag")
            if enable_pag:
                orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")

        text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)

        kwargs = {**load_config_kwargs, **kwargs}
        return text_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)

    @classmethod
    def from_pipe(cls, pipeline, **kwargs):
        r"""
        Instantiates a text-to-image Pytorch diffusion pipeline from another instantiated diffusion pipeline class.

        The from_pipe() method takes care of returning the correct pipeline class instance by finding the text-to-image
        pipeline linked to the pipeline class using pattern matching on pipeline class name.

        All the modules the pipeline contains will be used to initialize the new pipeline without reallocating
        additional memory.

        The pipeline is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pipeline (`DiffusionPipeline`):
                an instantiated `DiffusionPipeline` object
            controlnet (*optional*):
                If the key is present and not `None`, the ControlNet variant of the text-to-image class is selected;
                if present and `None`, the non-ControlNet variant is selected.
            enable_pag (`bool`, *optional*):
                If the key is present, selects the PAG variant when truthy and the non-PAG variant otherwise.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Components or optional `__init__` arguments of the target pipeline class; they take precedence over
                the values read from `pipeline`.

        Raises:
            ValueError: If no linked text-to-image class exists, or if a component required by the target class is
                neither passed in `kwargs` nor available on `pipeline`.

        Examples:

        ```py
        >>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image

        >>> pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5", requires_safety_checker=False
        ... )

        >>> pipe_t2i = AutoPipelineForText2Image.from_pipe(pipe_i2i)
        >>> image = pipe_t2i(prompt).images[0]
        ```
        """

        original_config = dict(pipeline.config)
        original_cls_name = pipeline.__class__.__name__

        # derive the pipeline class to instantiate
        text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, original_cls_name)

        if "controlnet" in kwargs:
            if kwargs["controlnet"] is not None:
                # `ControlNet` goes in front of `PAGPipeline` when the class is already a PAG variant.
                to_replace = "PAGPipeline" if "PAG" in text_2_image_cls.__name__ else "Pipeline"
                text_2_image_cls = _get_task_class(
                    AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
                    text_2_image_cls.__name__.replace("ControlNet", "").replace(
                        to_replace, "ControlNet" + to_replace
                    ),
                )
            else:
                text_2_image_cls = _get_task_class(
                    AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
                    text_2_image_cls.__name__.replace("ControlNet", ""),
                )

        if "enable_pag" in kwargs:
            enable_pag = kwargs.pop("enable_pag")
            if enable_pag:
                text_2_image_cls = _get_task_class(
                    AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
                    text_2_image_cls.__name__.replace("PAG", "").replace("Pipeline", "PAGPipeline"),
                )
            else:
                text_2_image_cls = _get_task_class(
                    AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
                    text_2_image_cls.__name__.replace("PAG", ""),
                )

        # define expected module and optional kwargs given the pipeline signature
        expected_modules, optional_kwargs = text_2_image_cls._get_signature_keys(text_2_image_cls)

        pretrained_model_name_or_path = original_config.pop("_name_or_path", None)

        # allow users pass modules in `kwargs` to override the original pipeline's components
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        original_class_obj = {
            k: component
            for k, component in pipeline.components.items()
            if k in expected_modules and k not in passed_class_obj
        }

        # allow users pass optional kwargs to override the original pipelines config attribute
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
        original_pipe_kwargs = {
            k: value
            for k, value in original_config.items()
            if k in optional_kwargs and k not in passed_pipe_kwargs
        }

        # config that were not expected by original pipeline is stored as private attribute
        # we will pass them as optional arguments if they can be accepted by the pipeline
        additional_pipe_kwargs = [
            k[1:]
            for k in original_config.keys()
            if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
        ]
        for k in additional_pipe_kwargs:
            original_pipe_kwargs[k] = original_config.pop(f"_{k}")

        text_2_image_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs}

        # store unused config as private attribute
        unused_original_config = {
            f"{'' if k.startswith('_') else '_'}{k}": value
            for k, value in original_config.items()
            if k not in text_2_image_kwargs
        }

        # A component may only be absent if the class being *built* declares it optional; what the source
        # pipeline considers optional is irrelevant to the target constructor.
        missing_modules = (
            set(expected_modules) - set(text_2_image_cls._optional_components) - set(text_2_image_kwargs.keys())
        )

        if missing_modules:
            raise ValueError(
                f"Pipeline {text_2_image_cls} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
            )

        model = text_2_image_cls(**text_2_image_kwargs)
        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
        model.register_to_config(**unused_original_config)

        return model
config that were not expected by original pipeline is stored as private attribute # we will pass them as optional arguments if they can be accepted by the pipeline additional_pipe_kwargs = [ k[1:] for k in original_config.keys() if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs ] for k in additional_pipe_kwargs: original_pipe_kwargs[k] = original_config.pop(f"_{k}") text_2_image_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs} # store unused config as private attribute unused_original_config = { f"{'' if k.startswith('_') else '_'}{k}": original_config[k] for k, v in original_config.items() if k not in text_2_image_kwargs } missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(text_2_image_kwargs.keys()) if len(missing_modules) > 0: raise ValueError( f"Pipeline {text_2_image_cls} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed" ) model = text_2_image_cls(**text_2_image_kwargs) model.register_to_config(_name_or_path=pretrained_model_name_or_path) model.register_to_config(**unused_original_config) return model class AutoPipelineForImage2Image(ConfigMixin): r""" [`AutoPipelineForImage2Image`] is a generic pipeline class that instantiates an image-to-image pipeline class. The specific underlying pipeline class is automatically selected from either the [`~AutoPipelineForImage2Image.from_pretrained`] or [`~AutoPipelineForImage2Image.from_pipe`] methods. This class cannot be instantiated using `__init__()` (throws an error). Class attributes: - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the diffusion pipeline's components. 
""" config_name = "model_index.json" def __init__(self, *args, **kwargs): raise EnvironmentError( f"{self.__class__.__name__} is designed to be instantiated " f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " f"`{self.__class__.__name__}.from_pipe(pipeline)` methods." ) @classmethod @validate_hf_hub_args def from_pretrained(cls, pretrained_model_or_path, **kwargs): r""" Instantiates a image-to-image Pytorch diffusion pipeline from pretrained pipeline weight. The from_pretrained() method takes care of returning the correct pipeline class instance by: 1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its config object 2. Find the image-to-image pipeline linked to the pipeline class using pattern matching on pipeline class name. If a `controlnet` argument is passed, it will instantiate a [`StableDiffusionControlNetImg2ImgPipeline`] object. The pipeline is set in evaluation mode (`model.eval()`) by default. If you get the error message below, you need to finetune the weights for your downstream task: ``` Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. ``` Parameters: pretrained_model_or_path (`str` or `os.PathLike`, *optional*): Can be either: - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted on the Hub. - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights saved using [`~DiffusionPipeline.save_pretrained`]. 
torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the dtype is automatically derived from the model's weights. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. custom_revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id similar to `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you’re downloading a model in China. 
We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn’t need to be defined for each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). max_memory (`Dict`, *optional*): A dictionary device identifier for the maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. offload_folder (`str` or `os.PathLike`, *optional*): The path to offload weights if device_map contains the value `"disk"`. offload_state_dict (`bool`, *optional*): If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` when there is some disk offload. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading only loading the pretrained weights and not initializing the weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to `True` will raise an error. use_safetensors (`bool`, *optional*, defaults to `None`): If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors weights. 
@classmethod
def from_pipe(cls, pipeline, **kwargs):
    r"""
    Instantiates a image-to-image Pytorch diffusion pipeline from another instantiated diffusion pipeline class.

    The from_pipe() method takes care of returning the correct pipeline class instance by finding the
    image-to-image pipeline linked to the pipeline class using pattern matching on pipeline class name.

    All the modules the pipeline contains will be used to initialize the new pipeline without reallocating
    additional memory.

    The pipeline is set in evaluation mode (`model.eval()`) by default.

    Parameters:
        pipeline (`DiffusionPipeline`):
            an instantiated `DiffusionPipeline` object
        kwargs (remaining dictionary of keyword arguments, *optional*):
            Modules or optional `__init__` arguments of the target pipeline; they take precedence over the
            components and config values of `pipeline`. Two keys steer the class selection: `controlnet`
            (a non-`None` value selects the ControlNet variant, `None` selects the plain variant) and
            `enable_pag` (truthy selects the PAG variant, falsy the non-PAG variant).

    Raises:
        `ValueError`: If a module expected by the target pipeline is neither passed, available on `pipeline`,
        nor optional for the target pipeline class.

    Examples:

    ```py
    >>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image

    >>> pipe_t2i = AutoPipelineForText2Image.from_pretrained(
    ...     "runwayml/stable-diffusion-v1-5", requires_safety_checker=False
    ... )

    >>> pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i)
    >>> image = pipe_i2i(prompt, image).images[0]
    ```
    """

    original_config = dict(pipeline.config)
    original_cls_name = pipeline.__class__.__name__

    # derive the pipeline class to instantiate
    image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, original_cls_name)

    if "controlnet" in kwargs:
        if kwargs["controlnet"] is not None:
            # "ControlNet" is inserted in front of the (PAG)Img2Img suffix of the class name
            to_replace = "Img2ImgPipeline"
            if "PAG" in image_2_image_cls.__name__:
                to_replace = "PAG" + to_replace
            image_2_image_cls = _get_task_class(
                AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
                image_2_image_cls.__name__.replace("ControlNet", "").replace(
                    to_replace, "ControlNet" + to_replace
                ),
            )
        else:
            image_2_image_cls = _get_task_class(
                AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
                image_2_image_cls.__name__.replace("ControlNet", ""),
            )

    if "enable_pag" in kwargs:
        enable_pag = kwargs.pop("enable_pag")
        if enable_pag:
            image_2_image_cls = _get_task_class(
                AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
                image_2_image_cls.__name__.replace("PAG", "").replace("Img2ImgPipeline", "PAGImg2ImgPipeline"),
            )
        else:
            image_2_image_cls = _get_task_class(
                AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
                image_2_image_cls.__name__.replace("PAG", ""),
            )

    # define expected module and optional kwargs given the pipeline signature
    expected_modules, optional_kwargs = image_2_image_cls._get_signature_keys(image_2_image_cls)

    pretrained_model_name_or_path = original_config.pop("_name_or_path", None)

    # allow users pass modules in `kwargs` to override the original pipeline's components
    passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
    original_class_obj = {
        k: component
        for k, component in pipeline.components.items()
        if k in expected_modules and k not in passed_class_obj
    }

    # allow users pass optional kwargs to override the original pipelines config attribute
    passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
    original_pipe_kwargs = {
        k: value
        for k, value in original_config.items()
        if k in optional_kwargs and k not in passed_pipe_kwargs
    }

    # config attribute that were not expected by original pipeline is stored as its private attribute
    # we will pass them as optional arguments if they can be accepted by the pipeline
    additional_pipe_kwargs = [
        k[1:]
        for k in original_config
        if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
    ]
    for k in additional_pipe_kwargs:
        original_pipe_kwargs[k] = original_config.pop(f"_{k}")

    image_2_image_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs}

    # store unused config as private attribute
    unused_original_config = {
        f"{'' if k.startswith('_') else '_'}{k}": value
        for k, value in original_config.items()
        if k not in image_2_image_kwargs
    }

    # Optionality is a property of the pipeline being built, so the *target* class decides which of its
    # expected modules may be absent (not the pipeline the components are taken from).
    missing_modules = (
        set(expected_modules) - set(image_2_image_cls._optional_components) - set(image_2_image_kwargs.keys())
    )

    if len(missing_modules) > 0:
        raise ValueError(
            f"Pipeline {image_2_image_cls} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
        )

    model = image_2_image_cls(**image_2_image_kwargs)
    model.register_to_config(_name_or_path=pretrained_model_name_or_path)
    model.register_to_config(**unused_original_config)

    return model
class AutoPipelineForInpainting(ConfigMixin):
    r"""

    [`AutoPipelineForInpainting`] is a generic pipeline class that instantiates an inpainting pipeline class. The
    specific underlying pipeline class is automatically selected from either the
    [`~AutoPipelineForInpainting.from_pretrained`] or [`~AutoPipelineForInpainting.from_pipe`] methods.

    This class cannot be instantiated using `__init__()` (throws an error).

    Class attributes:

        - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
          diffusion pipeline's components.

    """

    config_name = "model_index.json"

    def __init__(self, *args, **kwargs):
        raise EnvironmentError(
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
            f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
        )

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_or_path, **kwargs):
        r"""
        Instantiates a inpainting Pytorch diffusion pipeline from pretrained pipeline weight.

        The from_pretrained() method takes care of returning the correct pipeline class instance by:
            1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its
               config object
            2. Find the inpainting pipeline linked to the pipeline class using pattern matching on pipeline class
               name.

        If a `controlnet` argument is passed, it will instantiate a [`StableDiffusionControlNetInpaintPipeline`]
        object. If `enable_pag=True` is passed, the PAG variant of the pipeline class is looked up; both options
        are applied cumulatively to the class name.

        The pipeline is set in evaluation mode (`model.eval()`) by default.

        If you get the error message below, you need to finetune the weights for your downstream task:

        ```
        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
        ```

        Parameters:
            pretrained_model_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
                      hosted on the Hub.
                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                      saved using
                    [`~DiffusionPipeline.save_pretrained`].
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
                dtype is automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding
                the cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard
                cache is not used.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http':
                'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error
                messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the
                model won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated
                from `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any
                identifier allowed by Git.
            custom_revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
                `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when
                loading a custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn’t need to be defined for each
                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to
                the same device.

                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized
                `device_map`. For more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary device identifier for the maximum memory. Will default to the maximum memory
                available for each GPU and the available CPU RAM if unset.
            offload_folder (`str` or `os.PathLike`, *optional*):
                The path to offload weights if device_map contains the value `"disk"`.
            offload_state_dict (`bool`, *optional*):
                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU
                RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit.
                Defaults to `True` when there is some disk offload.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading only loading the pretrained weights and not initializing the weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading
                the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch,
                setting this argument to `True` will raise an error.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                weights. If set to `False`, safetensors weights are not loaded.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load and saveable variables (the pipeline components of the specific
                pipeline class). The overwritten components are passed directly to the pipelines `__init__`
                method. See example below for more information.
            variant (`str`, *optional*):
                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.

        <Tip>

        To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
        `huggingface-cli login`.

        </Tip>

        Examples:

        ```py
        >>> from diffusers import AutoPipelineForInpainting

        >>> pipeline = AutoPipelineForInpainting.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
        ```
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)

        load_config_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "token": token,
            "local_files_only": local_files_only,
            "revision": revision,
        }

        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
        orig_class_name = config["_class_name"]

        if "controlnet" in kwargs:
            orig_class_name = orig_class_name.replace("Pipeline", "ControlNetPipeline")
        if kwargs.pop("enable_pag", False):
            # Build on the name derived so far; restarting from config["_class_name"] would silently
            # drop the ControlNet rename applied just above.
            orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")

        inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)

        kwargs = {**load_config_kwargs, **kwargs}
        return inpainting_cls.from_pretrained(pretrained_model_or_path, **kwargs)

    @classmethod
    def from_pipe(cls, pipeline, **kwargs):
        r"""
        Instantiates a inpainting Pytorch diffusion pipeline from another instantiated diffusion pipeline class.

        The from_pipe() method takes care of returning the correct pipeline class instance by finding the
        inpainting pipeline linked to the pipeline class using pattern matching on pipeline class name.

        All the modules the pipeline class contain will be used to initialize the new pipeline without
        reallocating additional memory.

        The pipeline is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pipeline (`DiffusionPipeline`):
                an instantiated `DiffusionPipeline` object
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Modules or optional `__init__` arguments of the target pipeline; they take precedence over the
                components and config values of `pipeline`. `controlnet` and `enable_pag` additionally steer
                which pipeline class is selected.

        Raises:
            `ValueError`: If a module expected by the target pipeline is neither passed, available on
            `pipeline`, nor optional for the target pipeline class.

        Examples:

        ```py
        >>> from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting

        >>> pipe_t2i = AutoPipelineForText2Image.from_pretrained(
        ...     "DeepFloyd/IF-I-XL-v1.0", requires_safety_checker=False
        ... )

        >>> pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_t2i)
        >>> image = pipe_inpaint(prompt, image=init_image, mask_image=mask_image).images[0]
        ```
        """
        original_config = dict(pipeline.config)
        original_cls_name = pipeline.__class__.__name__

        # derive the pipeline class to instantiate
        inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, original_cls_name)

        if "controlnet" in kwargs:
            if kwargs["controlnet"] is not None:
                inpainting_cls = _get_task_class(
                    AUTO_INPAINT_PIPELINES_MAPPING,
                    inpainting_cls.__name__.replace("ControlNet", "").replace(
                        "InpaintPipeline", "ControlNetInpaintPipeline"
                    ),
                )
            else:
                inpainting_cls = _get_task_class(
                    AUTO_INPAINT_PIPELINES_MAPPING,
                    inpainting_cls.__name__.replace("ControlNetInpaintPipeline", "InpaintPipeline"),
                )

        if "enable_pag" in kwargs:
            enable_pag = kwargs.pop("enable_pag")
            if enable_pag:
                inpainting_cls = _get_task_class(
                    AUTO_INPAINT_PIPELINES_MAPPING,
                    inpainting_cls.__name__.replace("PAG", "").replace("InpaintPipeline", "PAGInpaintPipeline"),
                )
            else:
                inpainting_cls = _get_task_class(
                    AUTO_INPAINT_PIPELINES_MAPPING,
                    inpainting_cls.__name__.replace("PAGInpaintPipeline", "InpaintPipeline"),
                )

        # define expected module and optional kwargs given the pipeline signature
        expected_modules, optional_kwargs = inpainting_cls._get_signature_keys(inpainting_cls)

        pretrained_model_name_or_path = original_config.pop("_name_or_path", None)

        # allow users pass modules in `kwargs` to override the original pipeline's components
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        original_class_obj = {
            k: component
            for k, component in pipeline.components.items()
            if k in expected_modules and k not in passed_class_obj
        }

        # allow users pass optional kwargs to override the original pipelines config attribute
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
        original_pipe_kwargs = {
            k: value
            for k, value in original_config.items()
            if k in optional_kwargs and k not in passed_pipe_kwargs
        }

        # config that were not expected by original pipeline is stored as private attribute
        # we will pass them as optional arguments if they can be accepted by the pipeline
        additional_pipe_kwargs = [
            k[1:]
            for k in original_config
            if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
        ]
        for k in additional_pipe_kwargs:
            original_pipe_kwargs[k] = original_config.pop(f"_{k}")

        inpainting_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs}

        # store unused config as private attribute
        unused_original_config = {
            f"{'' if k.startswith('_') else '_'}{k}": value
            for k, value in original_config.items()
            if k not in inpainting_kwargs
        }

        # Optionality is a property of the pipeline being built, so the *target* class decides which of
        # its expected modules may be absent (not the pipeline the components are taken from).
        missing_modules = (
            set(expected_modules) - set(inpainting_cls._optional_components) - set(inpainting_kwargs.keys())
        )

        if len(missing_modules) > 0:
            raise ValueError(
                f"Pipeline {inpainting_cls} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
            )

        model = inpainting_cls(**inpainting_kwargs)
        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
        model.register_to_config(**unused_original_config)

        return model
OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline else: from .blip_image_processing import BlipImageProcessor from .modeling_blip2 import Blip2QFormerModel from .modeling_ctx_clip import ContextCLIPTextModel from .pipeline_blip_diffusion import BlipDiffusionPipeline ================================================ FILE: src/diffusers/pipelines/blip_diffusion/blip_image_processing.py ================================================ # coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class BlipImageProcessor(BaseImageProcessor):
    r"""
    Constructs a BLIP image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            the `do_resize` parameter in the `preprocess` method.
        size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can
            be overridden by the `resample` parameter in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
            `do_rescale` parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
            overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess`
            method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to center crop the image to `size` as the last transformation. Can be overridden by the
            `do_center_crop` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        do_center_crop: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 224, "width": 224}
        size = get_size_dict(size, default_to_square=True)

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.do_convert_rgb = do_convert_rgb
        self.do_center_crop = do_center_crop

    # Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the
                input image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is
                inferred from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.

        Raises:
            `ValueError`: If `size` does not contain both the `height` and `width` keys.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
        output_size = (size["height"], size["width"])
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        do_center_crop: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        do_convert_rgb: bool = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        The transformations are applied in this order: RGB conversion, resize, rescale, normalize, center crop
        and finally conversion to the requested channel dimension format.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255.
                If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Controls the size of the image after `resize`: the image is resized to `(size["height"],
                size["width"])`. The same `size` is passed to the center crop when `do_center_crop` is enabled.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image to `size` after normalization.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to normalize the image by if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                    - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is
                inferred from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `BatchFeature`: A batch with a single `pixel_values` entry holding the processed images, converted
            to `return_tensors` when it is set.

        Raises:
            `ValueError`: If the images are of an unsupported type, or if an enabled transformation is missing
            one of its parameters.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop

        size = size if size is not None else self.size
        size = get_size_dict(size, default_to_square=False)
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # Parenthesised on purpose: `and` binds tighter than `or`, so without the parentheses a missing
        # `resample` would raise even when resizing is disabled.
        if do_resize and (size is None or resample is None):
            raise ValueError("Size and resample must be specified if do_resize is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # PIL RGBA images are converted to RGB
        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            images = [
                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        if do_center_crop:
            images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
        return encoded_outputs

    # Follows diffusers.VaeImageProcessor.postprocess
    def postprocess(self, sample: torch.Tensor, output_type: str = "pil"):
        """
        Convert a model output tensor into the requested output format.

        Args:
            sample (`torch.Tensor`):
                Batch of images to convert. The values are mapped with `sample / 2 + 0.5` and clamped to
                `[0, 1]`. NOTE(review): the permute below assumes a (batch, channels, height, width) layout.
            output_type (`str`, *optional*, defaults to `"pil"`):
                One of `"pt"` (denormalized tensor), `"np"` (channels-last NumPy array) or `"pil"` (result of
                `numpy_to_pil`).

        Raises:
            `ValueError`: If `output_type` is not one of `"pt"`, `"np"` or `"pil"`.
        """
        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(
                f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
            )

        # Equivalent to diffusers.VaeImageProcessor.denormalize
        sample = (sample / 2 + 0.5).clamp(0, 1)
        if output_type == "pt":
            return sample

        # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "np":
            return sample
        # Output_type must be 'pil'
        sample = numpy_to_pil(sample)
        return sample
from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from transformers import BertTokenizer from transformers.activations import QuickGELUActivation as QuickGELU from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions, ) from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig from transformers.models.blip_2.modeling_blip_2 import ( Blip2Encoder, Blip2PreTrainedModel, Blip2QFormerAttention, Blip2QFormerIntermediate, Blip2QFormerOutput, ) from transformers.pytorch_utils import apply_chunking_to_forward from transformers.utils import ( logging, replace_return_docstrings, ) logger = logging.get_logger(__name__) # There is an implementation of Blip2 in `transformers`: https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py. # However, it does not support getting multimodal embeddings, so this module can be # replaced once a future `transformers` version supports that.
class Blip2TextEmbeddings(nn.Module): """Construct the embeddings from word and position embeddings.""" def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.config = config def forward( self, input_ids=None, position_ids=None, query_embeds=None, past_key_values_length=0, ): if input_ids is not None: seq_length = input_ids.size()[1] else: seq_length = 0 if position_ids is None: position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone() if input_ids is not None: embeddings = self.word_embeddings(input_ids) if self.position_embedding_type == "absolute": position_embeddings = self.position_embeddings(position_ids) embeddings = embeddings + position_embeddings if query_embeds is not None: batch_size = embeddings.shape[0] # repeat the query embeddings for batch size query_embeds = query_embeds.repeat(batch_size, 1, 1) embeddings = torch.cat((query_embeds, embeddings), dim=1) else: embeddings = query_embeds embeddings = embeddings.to(query_embeds.dtype) embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings # Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2 class Blip2VisionEmbeddings(nn.Module): def __init__(self, 
config: Blip2VisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) self.patch_embedding = nn.Conv2d( in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False ) self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches + 1 self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size = pixel_values.shape[0] target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) return embeddings # The Qformer encoder, which takes the visual embeddings, and the text input, to get multimodal embeddings class Blip2QFormerEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList( [Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, query_length=0, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions else None next_decoder_cache = () if use_cache else None for i in 
range(self.config.num_hidden_layers): layer_module = self.layer[i] if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False) and self.training: if use_cache: logger.warning( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions, query_length) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(layer_module), hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, ) else: layer_outputs = layer_module( hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, query_length, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if layer_module.has_cross_attention: all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple( v for v in [ hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions, ] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, ) # The layers making up the Qformer encoder class Blip2QFormerLayer(nn.Module): def __init__(self, config, layer_idx): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward 
self.seq_len_dim = 1 self.attention = Blip2QFormerAttention(config) self.layer_idx = layer_idx if layer_idx % config.cross_attention_frequency == 0: self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True) self.has_cross_attention = True else: self.has_cross_attention = False self.intermediate = Blip2QFormerIntermediate(config) self.intermediate_query = Blip2QFormerIntermediate(config) self.output_query = Blip2QFormerOutput(config) self.output = Blip2QFormerOutput(config) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, query_length=0, ): # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, past_key_value=self_attn_past_key_value, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:-1] present_key_value = self_attention_outputs[-1] if query_length > 0: query_attention_output = attention_output[:, :query_length, :] if self.has_cross_attention: if encoder_hidden_states is None: raise ValueError("encoder_hidden_states must be given for cross-attention layers") cross_attention_outputs = self.crossattention( query_attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions=output_attentions, ) query_attention_output = cross_attention_outputs[0] # add cross attentions if we output attention weights outputs = outputs + cross_attention_outputs[1:-1] layer_output = apply_chunking_to_forward( self.feed_forward_chunk_query, self.chunk_size_feed_forward, self.seq_len_dim, query_attention_output, ) if attention_output.shape[1] > query_length: layer_output_text = apply_chunking_to_forward( self.feed_forward_chunk, 
self.chunk_size_feed_forward, self.seq_len_dim, attention_output[:, query_length:, :], ) layer_output = torch.cat([layer_output, layer_output_text], dim=1) else: layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output, ) outputs = (layer_output,) + outputs outputs = outputs + (present_key_value,) return outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output def feed_forward_chunk_query(self, attention_output): intermediate_output = self.intermediate_query(attention_output) layer_output = self.output_query(intermediate_output, attention_output) return layer_output # ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder class ProjLayer(nn.Module): def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12): super().__init__() # Dense1 -> Act -> Dense2 -> Drop -> Res -> Norm self.dense1 = nn.Linear(in_dim, hidden_dim) self.act_fn = QuickGELU() self.dense2 = nn.Linear(hidden_dim, out_dim) self.dropout = nn.Dropout(drop_p) self.LayerNorm = nn.LayerNorm(out_dim, eps=eps) def forward(self, x): x_in = x x = self.LayerNorm(x) x = self.dropout(self.dense2(self.act_fn(self.dense1(x)))) + x_in return x # Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2 class Blip2VisionModel(Blip2PreTrainedModel): main_input_name = "pixel_values" config_class = Blip2VisionConfig def __init__(self, config: Blip2VisionConfig): super().__init__(config) self.config = config embed_dim = config.hidden_size self.embeddings = Blip2VisionEmbeddings(config) self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.encoder = Blip2Encoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_init() 
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig) def forward( self, pixel_values: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None: raise ValueError("You have to specify pixel_values") hidden_states = self.embeddings(pixel_values) hidden_states = self.pre_layernorm(hidden_states) encoder_outputs = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] last_hidden_state = self.post_layernorm(last_hidden_state) pooled_output = last_hidden_state[:, 0, :] pooled_output = self.post_layernorm(pooled_output) if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) def get_input_embeddings(self): return self.embeddings # Qformer model, used to get multimodal embeddings from the text and image inputs class Blip2QFormerModel(Blip2PreTrainedModel): """ Querying Transformer (Q-Former), used in BLIP-2. 
""" def __init__(self, config: Blip2Config): super().__init__(config) self.config = config self.embeddings = Blip2TextEmbeddings(config.qformer_config) self.visual_encoder = Blip2VisionModel(config.vision_config) self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) if not hasattr(config, "tokenizer") or config.tokenizer is None: self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right") else: self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right") self.tokenizer.add_special_tokens({"bos_token": "[DEC]"}) self.proj_layer = ProjLayer( in_dim=config.qformer_config.hidden_size, out_dim=config.qformer_config.hidden_size, hidden_dim=config.qformer_config.hidden_size * 4, drop_p=0.1, eps=1e-12, ) self.encoder = Blip2QFormerEncoder(config.qformer_config) self.post_init() def get_input_embeddings(self): return self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) def get_extended_attention_mask( self, attention_mask: torch.Tensor, input_shape: Tuple[int], device: torch.device, has_query: bool = False, ) -> torch.Tensor: """ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. Arguments: attention_mask (`torch.Tensor`): Mask with ones indicating tokens to attend to, zeros for tokens to ignore. input_shape (`Tuple[int]`): The shape of the input to the model. device (`torch.device`): The device of the input to the model. Returns: `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. 
""" # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. if attention_mask.dim() == 3: extended_attention_mask = attention_mask[:, None, :, :] elif attention_mask.dim() == 2: # Provided a padding mask of dimensions [batch_size, seq_length] # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] extended_attention_mask = attention_mask[:, None, None, :] else: raise ValueError( "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( input_shape, attention_mask.shape ) ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask def forward( self, text_input=None, image_input=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): r""" encoder_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. 
Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. past_key_values (`tuple(tuple(torch.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of: shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. use_cache (`bool`, `optional`): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). """ text = self.tokenizer(text_input, return_tensors="pt", padding=True) text = text.to(self.device) input_ids = text.input_ids batch_size = input_ids.shape[0] query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device) attention_mask = torch.cat([query_atts, text.attention_mask], dim=1) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # past_key_values_length past_key_values_length = ( past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 ) query_length = self.query_tokens.shape[1] embedding_output = self.embeddings( input_ids=input_ids, query_embeds=self.query_tokens, past_key_values_length=past_key_values_length, ) # embedding_output = self.layernorm(query_embeds) # embedding_output = self.dropout(embedding_output) input_shape = embedding_output.size()[:-1] batch_size, seq_length = input_shape device 
= embedding_output.device image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state # image_embeds_frozen = torch.ones_like(image_embeds_frozen) encoder_hidden_states = image_embeds_frozen if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if encoder_hidden_states is not None: if isinstance(encoder_hidden_states, list): encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() else: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if isinstance(encoder_attention_mask, list): encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] elif encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers) encoder_outputs = self.encoder( embedding_output, 
attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, query_length=query_length, ) sequence_output = encoder_outputs[0] pooled_output = sequence_output[:, 0, :] if not return_dict: return self.proj_layer(sequence_output[:, :query_length, :]) return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, ) ================================================ FILE: src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py ================================================ # Copyright 2024 Salesforce.com, inc. # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Optional, Tuple, Union import torch from torch import nn from transformers import CLIPPreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithPooling from transformers.models.clip.configuration_clip import CLIPTextConfig from transformers.models.clip.modeling_clip import CLIPEncoder def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) # This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip # Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer # They pass through the clip model, along with the text embeddings, and interact with them using self attention class ContextCLIPTextModel(CLIPPreTrainedModel): config_class = CLIPTextConfig _no_split_modules = ["CLIPEncoderLayer"] def __init__(self, config: CLIPTextConfig): super().__init__(config) self.text_model = ContextCLIPTextTransformer(config) # Initialize weights and apply final processing self.post_init() def forward( self, ctx_embeddings: torch.Tensor = None, ctx_begin_pos: list = None, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: return self.text_model( ctx_embeddings=ctx_embeddings, ctx_begin_pos=ctx_begin_pos, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, 
output_hidden_states=output_hidden_states, return_dict=return_dict, ) class ContextCLIPTextTransformer(nn.Module): def __init__(self, config: CLIPTextConfig): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = ContextCLIPTextEmbeddings(config) self.encoder = CLIPEncoder(config) self.final_layer_norm = nn.LayerNorm(embed_dim) def forward( self, ctx_embeddings: torch.Tensor, ctx_begin_pos: list, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is None: raise ValueError("You have to specify either input_ids") input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) hidden_states = self.embeddings( input_ids=input_ids, position_ids=position_ids, ctx_embeddings=ctx_embeddings, ctx_begin_pos=ctx_begin_pos, ) bsz, seq_len = input_shape if ctx_embeddings is not None: seq_len += ctx_embeddings.size(1) # CLIP's text model uses causal mask, prepare it here. 
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( hidden_states.device ) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _expand_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, causal_attention_mask=causal_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = encoder_outputs[0] last_hidden_state = self.final_layer_norm(last_hidden_state) # text_embeds.shape = [batch_size, sequence_length, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 pooled_output = last_hidden_state[ torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1), ] if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) def _build_causal_attention_mask(self, bsz, seq_len, dtype): # lazily create causal attention mask, with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) mask.fill_(torch.tensor(torch.finfo(dtype).min)) mask.triu_(1) # zero out the lower diagonal mask = mask.unsqueeze(1) # expand mask return mask class ContextCLIPTextEmbeddings(nn.Module): def __init__(self, config: CLIPTextConfig): super().__init__() embed_dim = config.hidden_size self.token_embedding = 
nn.Embedding(config.vocab_size, embed_dim) self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) def forward( self, ctx_embeddings: torch.Tensor, ctx_begin_pos: list, input_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: if ctx_embeddings is None: ctx_len = 0 else: ctx_len = ctx_embeddings.shape[1] seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len if position_ids is None: position_ids = self.position_ids[:, :seq_length] if inputs_embeds is None: inputs_embeds = self.token_embedding(input_ids) # for each input embeddings, add the ctx embeddings at the correct position input_embeds_ctx = [] bsz = inputs_embeds.shape[0] if ctx_embeddings is not None: for i in range(bsz): cbp = ctx_begin_pos[i] prefix = inputs_embeds[i, :cbp] # remove the special token embedding suffix = inputs_embeds[i, cbp:] input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0)) inputs_embeds = torch.stack(input_embeds_ctx, dim=0) position_embeddings = self.position_embedding(position_ids) embeddings = inputs_embeds + position_embeddings return embeddings ================================================ FILE: src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py ================================================ # Copyright 2024 Salesforce.com, inc. # Copyright 2024 The HuggingFace Team. All rights reserved.# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Union import PIL.Image import torch from transformers import CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import PNDMScheduler from ...utils import ( logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .blip_image_processing import BlipImageProcessor from .modeling_blip2 import Blip2QFormerModel from .modeling_ctx_clip import ContextCLIPTextModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers.pipelines import BlipDiffusionPipeline >>> from diffusers.utils import load_image >>> import torch >>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained( ... "Salesforce/blipdiffusion", torch_dtype=torch.float16 ... ).to("cuda") >>> cond_subject = "dog" >>> tgt_subject = "dog" >>> text_prompt_input = "swimming underwater" >>> cond_image = load_image( ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg" ... ) >>> guidance_scale = 7.5 >>> num_inference_steps = 25 >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate" >>> output = blip_diffusion_pipe( ... text_prompt_input, ... cond_image, ... cond_subject, ... tgt_subject, ... guidance_scale=guidance_scale, ... 
class BlipDiffusionPipeline(DiffusionPipeline):
    """
    Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        tokenizer ([`CLIPTokenizer`]):
            Tokenizer for the text encoder
        text_encoder ([`ContextCLIPTextModel`]):
            Text encoder to encode the text prompt
        vae ([`AutoencoderKL`]):
            VAE model to map the latents to the image
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        scheduler ([`PNDMScheduler`]):
             A scheduler to be used in combination with `unet` to generate image latents.
        qformer ([`Blip2QFormerModel`]):
            QFormer model to get multi-modal embeddings from the text and image.
        image_processor ([`BlipImageProcessor`]):
            Image Processor to preprocess and postprocess the image.
        ctx_begin_pos (int, `optional`, defaults to 2):
            Position of the context token in the text encoder.
        mean (`List[float]`, `optional`):
            Stored in the pipeline config and forwarded as `image_mean` when preprocessing the reference image.
        std (`List[float]`, `optional`):
            Stored in the pipeline config and forwarded as `image_std` when preprocessing the reference image.
    """

    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: ContextCLIPTextModel,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: PNDMScheduler,
        qformer: Blip2QFormerModel,
        image_processor: BlipImageProcessor,
        ctx_begin_pos: int = 2,
        mean: List[float] = None,
        std: List[float] = None,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            qformer=qformer,
            image_processor=image_processor,
        )
        self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)

    def get_query_embeddings(self, input_image, src_subject):
        """Run the QFormer on the reference image and source subject text and return its (non-dict) output."""
        return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)

    # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
    def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
        """
        Prefix every prompt with its target subject and repeat it to amplify it.

        Each output string is `"a {tgt_subject} {prompt}"` repeated `int(prompt_strength * prompt_reps)` times and
        joined with `", "`. Prompts and subjects are paired with `zip`, so the shorter sequence sets the length.
        """
        rv = []
        for prompt, tgt_subject in zip(prompts, tgt_subjects):
            prompt = f"a {tgt_subject} {prompt.strip()}"
            # a trick to amplify the prompt
            rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))

        return rv

    # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels, height, width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def encode_prompt(self, query_embeds, prompt, device=None):
        """
        Encode `prompt` with the text encoder, injecting `query_embeds` as context embeddings.

        The prompt is tokenized to `max_position_embeddings - num_query_tokens` tokens so that the query tokens
        fit in the text encoder's position budget. Returns the first element of the text encoder output.
        """
        device = device or self._execution_device

        # embeddings for prompt, with query_embeds as context
        max_len = self.text_encoder.text_model.config.max_position_embeddings
        max_len -= self.qformer.config.num_query_tokens

        tokenized_prompt = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=max_len,
            return_tensors="pt",
        ).to(device)

        batch_size = query_embeds.shape[0]
        ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size

        text_embeddings = self.text_encoder(
            input_ids=tokenized_prompt.input_ids,
            ctx_embeddings=query_embeds,
            ctx_begin_pos=ctx_begin_pos,
        )[0]

        return text_embeddings

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: List[str],
        reference_image: PIL.Image.Image,
        source_subject_category: List[str],
        target_subject_category: List[str],
        latents: Optional[torch.Tensor] = None,
        guidance_scale: float = 7.5,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        neg_prompt: Optional[str] = "",
        prompt_strength: float = 1.0,
        prompt_reps: int = 20,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`List[str]`):
                The prompt or prompts to guide the image generation. A single `str` is wrapped into a list.
            reference_image (`PIL.Image.Image`):
                The reference image to condition the generation on.
            source_subject_category (`List[str]`):
                The source subject category. A single `str` is wrapped into a list.
            target_subject_category (`List[str]`):
                The target subject category. A single `str` is wrapped into a list.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by random sampling.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            height (`int`, *optional*, defaults to 512):
                The height of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            neg_prompt (`str`, *optional*, defaults to ""):
                The prompt not to guide the image generation. `None` is treated as the empty string. Ignored when not
                using guidance (i.e., ignored if `guidance_scale <= 1`).
            prompt_strength (`float`, *optional*, defaults to 1.0):
                The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
                to amplify the prompt.
            prompt_reps (`int`, *optional*, defaults to 20):
                The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        device = self._execution_device

        reference_image = self.image_processor.preprocess(
            reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
        )["pixel_values"]
        reference_image = reference_image.to(device)

        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(source_subject_category, str):
            source_subject_category = [source_subject_category]
        if isinstance(target_subject_category, str):
            target_subject_category = [target_subject_category]

        batch_size = len(prompt)

        prompt = self._build_prompt(
            prompts=prompt,
            tgt_subjects=target_subject_category,
            prompt_strength=prompt_strength,
            prompt_reps=prompt_reps,
        )
        query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
        text_embeddings = self.encode_prompt(query_embeds, prompt, device)

        # Guidance is decided once: `guidance_scale` does not change during the call.
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            max_length = self.text_encoder.text_model.config.max_position_embeddings

            # `neg_prompt` is annotated Optional, so map None to the empty prompt instead of handing None to
            # the tokenizer. Truncate like `encode_prompt` does so an over-long negative prompt cannot exceed
            # the text encoder's position budget.
            uncond_input = self.tokenizer(
                [neg_prompt or ""] * batch_size,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(
                input_ids=uncond_input.input_ids.to(device),
                ctx_embeddings=None,
            )[0]
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
        latents = self.prepare_latents(
            batch_size=batch_size,
            num_channels=self.unet.config.in_channels,
            height=height // scale_down_factor,
            width=width // scale_down_factor,
            generator=generator,
            latents=latents,
            dtype=self.unet.dtype,
            device=device,
        )

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            noise_pred = self.unet(
                latent_model_input,
                timestep=t,
                encoder_hidden_states=text_embeddings,
                down_block_additional_residuals=None,
                mid_block_additional_residual=None,
            )["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
            )["prev_sample"]

        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
_import_structure = {} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_cogvideox import CogVideoXPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py ================================================ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect import math from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union import torch from transformers import T5EncoderModel, T5Tokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import BaseOutput, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```python >>> import torch >>> from diffusers import CogVideoXPipeline >>> from diffusers.utils import export_to_video >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") >>> prompt = ( ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " ... "atmosphere of this unique musical performance." ... 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure the scheduler via `set_timesteps` and hand back the resulting schedule.

    Exactly one of `num_inference_steps`, `timesteps` or `sigmas` drives the schedule. Extra keyword arguments are
    forwarded untouched to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler whose timesteps are configured and read back.
        num_inference_steps (`int`):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration, and the number of inference steps
        (the length of the schedule when a custom schedule was supplied, otherwise `num_inference_steps` as passed).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    custom_timesteps = timesteps is not None
    custom_sigmas = sigmas is not None

    if custom_timesteps and custom_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Plain case: let the scheduler build its own schedule from the step count.
    if not custom_timesteps and not custom_sigmas:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: the scheduler must expose the matching keyword on `set_timesteps`.
    accepted_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if custom_timesteps:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. text_encoder ([`T5EncoderModel`]): Frozen text-encoder. CogVideoX uses [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. tokenizer (`T5Tokenizer`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). transformer ([`CogVideoXTransformer3DModel`]): A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `transformer` to denoise the encoded video latents. """ _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = [ "latents", "prompt_embeds", "negative_prompt_embeds", ] def __init__( self, tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, vae: AutoencoderKLCogVideoX, transformer: CogVideoXTransformer3DModel, scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], ): super().__init__() self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) self.vae_scale_factor_spatial = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.vae_scale_factor_temporal = ( self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 ) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_videos_per_prompt: int = 1, max_sequence_length: int = 226, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype 
prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) return prompt_embeds def encode_prompt( self, prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: int = 226, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): Whether to use classifier free guidance or not. num_videos_per_prompt (`int`, *optional*, defaults to 1): Number of videos that should be generated per prompt. torch device to place the resulting embeddings on prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. device: (`torch.device`, *optional*): torch device dtype: (`torch.dtype`, *optional*): torch dtype """ device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_embeds = self._get_t5_prompt_embeds( prompt=prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype, ) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) negative_prompt_embeds = self._get_t5_prompt_embeds( prompt=negative_prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype, ) return prompt_embeds, negative_prompt_embeds def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): shape = ( batch_size, (num_frames - 1) // self.vae_scale_factor_temporal + 1, num_channels_latents, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents def decode_latents(self, latents: torch.Tensor, num_seconds: int): latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] latents = 1 / self.vae.config.scaling_factor * latents frames = [] for i in range(num_seconds): start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3) current_frames = self.vae.decode(latents[:, :, start_frame:end_frame]).sample frames.append(current_frames) self.vae.clear_fake_context_parallel_cache() frames = torch.cat(frames, dim=2) return frames # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other 
schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs def check_inputs( self, prompt, height, width, negative_prompt, callback_on_step_end_tensor_inputs, prompt_embeds=None, negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. 
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        num_frames: int = 48,
        fps: int = 8,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        use_dynamic_cfg: bool = False,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 226,
    ) -> Union[CogVideoXPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the video generation. If not defined, one has to pass
                `prompt_embeds` instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. Only encoded when classifier-free
                guidance is active (`guidance_scale > 1`) and `negative_prompt_embeds` is not passed.
            height (`int`, *optional*, defaults to 480):
                The height in pixels of the generated video. A falsy value falls back to
                `self.transformer.config.sample_size * self.vae_scale_factor_spatial`.
            width (`int`, *optional*, defaults to 720):
                The width in pixels of the generated video. A falsy value falls back to
                `self.transformer.config.sample_size * self.vae_scale_factor_spatial`.
            num_frames (`int`, defaults to `48`):
                Number of frames to generate. The call asserts `num_frames <= 48` and `num_frames % fps == 0`.
                One extra frame is added internally (`num_frames += 1`) before the latents are prepared.
            fps (`int`, defaults to `8`):
                Frames per second. The call asserts `fps == 8`; it is also used to derive the number of
                seconds passed to `decode_latents` (`num_frames // fps`).
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality result at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 6):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate results that are closely linked to the text `prompt`,
                usually at the expense of lower quality.
            use_dynamic_cfg (`bool`, defaults to `False`):
                When `True`, the guidance scale applied at each step is recomputed from `guidance_scale`, the
                current timestep value and `num_inference_steps` using a cosine-based schedule.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt. NOTE: the value passed in is currently overridden to
                `1` inside this method.
            eta (`float`, defaults to `0.0`):
                Forwarded to `scheduler.step` when the scheduler's `step` accepts an `eta` argument.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video. `"latent"` returns the final latents without decoding;
                any other value is forwarded to `video_processor.postprocess_video`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`]
                instead of a plain tuple.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `226`):
                Maximum sequence length in encoded prompt. Must be consistent with
                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.

        Examples:

        Returns:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is the generated video (or the latents when
            `output_type == "latent"`).
        """

        # NOTE(review): input validation via `assert` is skipped when Python runs with -O; consider whether an
        # explicit exception is preferable here.
        assert (
            num_frames <= 48 and num_frames % fps == 0 and fps == 8
        ), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX."

        # Callback objects carry their own list of tensor inputs, which takes precedence over the argument.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
        # The caller-supplied value is discarded; only one video per prompt is generated.
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        # Negative embeddings come first so that `chunk(2)` below yields (uncond, text).
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        # One extra frame is requested on top of the user-facing `num_frames`.
        num_frames += 1
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
            old_pred_original_sample = None
            for i, t in enumerate(timesteps):
                # When interrupted, every remaining iteration is skipped without touching the latents.
                if self.interrupt:
                    continue

                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    return_dict=False,
                )[0]
                noise_pred = noise_pred.float()

                # perform guidance
                # NOTE(review): the dynamic schedule is driven by the timestep value `t.item()` relative to
                # `num_inference_steps`, not by the loop index `i` — confirm this is intended.
                if use_dynamic_cfg:
                    self._guidance_scale = 1 + guidance_scale * (
                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                    )
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                else:
                    # The DPM scheduler additionally receives the previous step's predicted original sample
                    # and the previous timestep, and returns the updated prediction for the next step.
                    latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        latents,
                        **extra_step_kwargs,
                        return_dict=False,
                    )
                latents = latents.to(prompt_embeds.dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # Requested tensors are looked up by name among this function's local variables.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # The callback may replace any of these tensors for the following steps.
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            # `num_frames` already includes the extra frame added in step 5.
            video = self.decode_latents(latents, num_frames // fps)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return CogVideoXPipelineOutput(frames=video)
class ConsistencyModelPipeline(DiffusionPipeline):
    r"""
    Image generation pipeline driven by a consistency model, with optional class conditioning.

    The class derives from [`DiffusionPipeline`]; see that superclass for the generic pipeline methods (downloading,
    saving, running on a particular device, etc.).

    Args:
        unet ([`UNet2DModel`]):
            The `UNet2DModel` that denoises the image sample.
        scheduler ([`SchedulerMixin`]):
            The scheduler paired with `unet` during denoising. Currently only compatible with
            [`CMStochasticIterativeScheduler`].
    """

    model_cpu_offload_seq = "unet"

    def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None:
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

        # This pipeline is built without a safety checker.
        self.safety_checker = None

    def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
        """
        Return the starting sample for the denoising loop.

        When `latents` is given it is moved to `device`/`dtype`; otherwise a tensor of shape
        `(batch_size, num_channels, height, width)` is drawn with `randn_tensor`. Either way the result is multiplied
        by `self.scheduler.init_noise_sigma`.

        Raises:
            ValueError: If `generator` is a list whose length differs from `batch_size`.
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is not None:
            initial_noise = latents.to(device=device, dtype=dtype)
        else:
            noise_shape = (batch_size, num_channels, height, width)
            initial_noise = randn_tensor(noise_shape, generator=generator, device=device, dtype=dtype)

        # The scheduler dictates the standard deviation of the initial noise.
        return initial_noise * self.scheduler.init_noise_sigma

    # Follows diffusers.VaeImageProcessor.postprocess
    def postprocess_image(self, sample: torch.Tensor, output_type: str = "pil"):
        """
        Convert a model-space `sample` into the requested output format (`"pt"`, `"np"` or `"pil"`).

        Raises:
            ValueError: If `output_type` is not one of the three supported values.
        """
        if output_type not in ("pt", "np", "pil"):
            raise ValueError(
                f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
            )

        # Same arithmetic as diffusers.VaeImageProcessor.denormalize: x / 2 + 0.5, clamped to [0, 1].
        denormalized = (sample / 2 + 0.5).clamp(0, 1)
        if output_type == "pt":
            return denormalized

        # Same conversion as diffusers.VaeImageProcessor.pt_to_numpy: permute axes (0, 2, 3, 1), then to numpy.
        channels_last = denormalized.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "np":
            return channels_last

        # Only "pil" is left at this point.
        return self.numpy_to_pil(channels_last)

    def prepare_class_labels(self, batch_size, device, class_labels=None):
        """
        Normalise `class_labels` into a tensor on `device`, or return `None` for a non class-conditional UNet.

        A list becomes an int tensor, a single int becomes a one-element int tensor (only valid for
        `batch_size == 1`), and `None` triggers random sampling of `batch_size` labels. Any other value is moved to
        `device` as is.
        """
        num_classes = self.unet.config.num_class_embeds
        if num_classes is None:
            # The model is not class-conditional, so any supplied labels are dropped.
            return None

        if isinstance(class_labels, list):
            labels = torch.tensor(class_labels, dtype=torch.int)
        elif isinstance(class_labels, int):
            assert batch_size == 1, "Batch size must be 1 if classes is an int"
            labels = torch.tensor([class_labels], dtype=torch.int)
        elif class_labels is None:
            # Draw one random label per batch element.
            # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils
            labels = torch.randint(0, num_classes, size=(batch_size,))
        else:
            labels = class_labels

        return labels.to(device)

    def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps):
        """
        Validate the arguments of `__call__`.

        Raises:
            ValueError: If neither `num_inference_steps` nor `timesteps` is given, if `latents` does not have shape
                `(batch_size, 3, img_size, img_size)`, or if `callback_steps` is not a positive integer.
        """
        steps_given = num_inference_steps is not None
        schedule_given = timesteps is not None

        if not steps_given and not schedule_given:
            raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")

        if steps_given and schedule_given:
            # Supplying both is tolerated; the explicit schedule wins.
            logger.warning(
                f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;"
                " `timesteps` will be used over `num_inference_steps`."
            )

        if latents is not None:
            expected_shape = (batch_size, 3, img_size, img_size)
            if latents.shape != expected_shape:
                raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.")

        callback_steps_ok = callback_steps is not None and isinstance(callback_steps, int) and callback_steps > 0
        if not callback_steps_ok:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        batch_size: int = 1,
        class_labels: Optional[Union[torch.Tensor, List[int], int]] = None,
        num_inference_steps: int = 1,
        timesteps: List[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                How many images to generate.
            class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*):
                Class labels used to condition a class-conditional consistency model. Ignored when the model is not
                class-conditional.
            num_inference_steps (`int`, *optional*, defaults to 1):
                Number of denoising steps. More steps usually give a higher quality image at the cost of slower
                inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps for the denoising process, in descending order. When given they take precedence
                over `num_inference_steps`; otherwise the scheduler derives the schedule from
                `num_inference_steps`.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents to start from. When omitted, the starting sample is drawn with the
                supplied `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format of the generated image: `"pil"`, `"np"` or `"pt"`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                Function invoked every `callback_steps` steps as
                `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                How often `callback` is invoked. With the default it is invoked at every step.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                An [`~pipelines.ImagePipelineOutput`] when `return_dict` is `True`, otherwise a `tuple` whose first
                element holds the generated images.
        """
        # 0. Call parameters taken from the model and the pipeline
        image_size = self.unet.config.sample_size
        execution_device = self._execution_device

        # 1. Validate the inputs
        self.check_inputs(num_inference_steps, timesteps, latents, batch_size, image_size, callback_steps)

        # 2. Initial sample x_0 ~ N(0, sigma_0^2 * I)
        sample = self.prepare_latents(
            batch_size=batch_size,
            num_channels=self.unet.config.in_channels,
            height=image_size,
            width=image_size,
            dtype=self.unet.dtype,
            device=execution_device,
            generator=generator,
            latents=latents,
        )

        # 3. Class labels (None unless the model is class-conditional)
        class_labels = self.prepare_class_labels(batch_size, execution_device, class_labels=class_labels)

        # 4. Timestep schedule: an explicit schedule overrides `num_inference_steps`
        use_custom_schedule = timesteps is not None
        if use_custom_schedule:
            self.scheduler.set_timesteps(timesteps=timesteps, device=execution_device)
        else:
            self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps
        if use_custom_schedule:
            num_inference_steps = len(timesteps)

        # 5. Denoising loop (multistep sampling, Algorithm 1 in the paper)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for step_index, t in enumerate(timesteps):
                model_input = self.scheduler.scale_model_input(sample, t)
                prediction = self.unet(model_input, t, class_labels=class_labels, return_dict=False)[0]
                sample = self.scheduler.step(prediction, t, sample, generator=generator)[0]

                progress_bar.update()
                # Report progress to the caller every `callback_steps` steps.
                if callback is not None and step_index % callback_steps == 0:
                    callback(step_index, t, sample)

        # 6. Convert the final sample to the requested output format
        image = self.postprocess_image(sample, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if return_dict:
            return ImagePipelineOutput(images=image)
        return (image,)
["StableDiffusionXLControlNetInpaintPipeline"] _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"] _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"] try: if not (is_transformers_available() and is_flax_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_flax_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) else: _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .multicontrolnet import MultiControlNetModel from .pipeline_controlnet import StableDiffusionControlNetPipeline from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline try: if not (is_transformers_available() and is_flax_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 else: from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): 
class MultiControlNetModel(ModelMixin):
    r"""
    Multiple `ControlNetModel` wrapper class for Multi-ControlNet

    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
    compatible with `ControlNetModel`.

    Args:
        controlnets (`List[ControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `ControlNetModel` as a list.
    """

    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        """
        Run every wrapped ControlNet on the same inputs and sum their residuals.

        Each ControlNet in `self.nets` is paired, in order, with one entry of `controlnet_cond` and one entry of
        `conditioning_scale`; all remaining arguments are passed unchanged to every ControlNet. Because the pairing
        uses `zip`, only as many ControlNets run as the shortest of the three sequences.

        Returns:
            A pair `(down_block_res_samples, mid_block_res_sample)` holding the element-wise sum of the outputs of
            all ControlNets that were run.
        """
        # NOTE(review): each ControlNet result is unpacked as a 2-tuple, so callers presumably pass
        # `return_dict=False` -- confirm against the pipelines that use this wrapper.
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            down_samples, mid_sample = controlnet(
                sample=sample,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=image,
                conditioning_scale=scale,
                class_labels=class_labels,
                timestep_cond=timestep_cond,
                attention_mask=attention_mask,
                added_cond_kwargs=added_cond_kwargs,
                cross_attention_kwargs=cross_attention_kwargs,
                guess_mode=guess_mode,
                return_dict=return_dict,
            )

            # merge samples
            if i == 0:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                # `+=` accumulates into the object returned by the first ControlNet.
                mid_block_res_sample += mid_sample

        return down_block_res_samples, mid_block_res_sample

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        save_function: Callable = None,
        safe_serialization: bool = True,
        variant: Optional[str] = None,
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method.

        The first ControlNet is written to `save_directory` itself; the following ones are written to sibling
        directories named `save_directory + "_1"`, `save_directory + "_2"`, and so on.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful when in distributed training like
                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
                the main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                need to replace `torch.save` by another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
            variant (`str`, *optional*):
                Weight-file variant; forwarded unchanged to each wrapped ControlNet's `save_pretrained`.
        """
        # Convert once to a plain path string: `pathlib.Path` (and other `os.PathLike`) objects do not support
        # `+` with a `str`, which the suffixing below relies on.
        base_directory = os.fspath(save_directory)

        for idx, controlnet in enumerate(self.nets):
            suffix = "" if idx == 0 else f"_{idx}"
            controlnet.save_pretrained(
                base_directory + suffix,
                is_main_process=is_main_process,
                save_function=save_function,
                safe_serialization=safe_serialization,
                variant=variant,
            )

    @classmethod
    def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.

        ControlNets are loaded from `pretrained_model_path`, then `pretrained_model_path + "_1"`,
        `pretrained_model_path + "_2"`, ... until the next directory in that sequence does not exist. This mirrors the
        layout written by [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`].

        Parameters:
            pretrained_model_path (`str` or `os.PathLike`):
                A path to a *directory* containing model weights saved using
                [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
                `./my_model_directory/controlnet`.
            kwargs:
                Forwarded unchanged to `ControlNetModel.from_pretrained` for every ControlNet that is loaded, e.g.
                `torch_dtype`, `output_loading_info`, `device_map`, `max_memory`, `low_cpu_mem_usage`, `variant`
                and `use_safetensors`.

        Raises:
            ValueError: If `pretrained_model_path` is not an existing directory, so no ControlNet could be loaded.
        """
        # Convert once to a plain path string so that `os.PathLike` inputs can be suffixed with `+`.
        base_path = os.fspath(pretrained_model_path)

        idx = 0
        controlnets = []

        # load controlnet and append to list until no controlnet directory exists anymore
        # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_pretrained`
        # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
        model_path_to_load = base_path
        while os.path.isdir(model_path_to_load):
            controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
            controlnets.append(controlnet)

            idx += 1
            model_path_to_load = base_path + f"_{idx}"

        logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")

        if not controlnets:
            # The first ControlNet lives in the unsuffixed directory, so that is the path to report as missing.
            raise ValueError(
                f"No ControlNets found under {os.path.dirname(base_path)}. Expected at least {base_path}."
            )

        return cls(controlnets)
import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from .multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install opencv-python transformers accelerate >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> import cv2 >>> from PIL import Image >>> # download an image >>> image = load_image( ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" ... 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read the timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps to use when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Passed to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's own spacing. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's own spacing. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's timestep schedule and the number of inference steps. For a custom
        schedule the count is the length of the schedule; otherwise it is `num_inference_steps` as passed in.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not accept
            the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler space the steps itself.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: the scheduler must explicitly accept the corresponding keyword.
    accepted_params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. 
For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) if isinstance(controlnet, (list, tuple)): controlnet = MultiControlNetModel(controlnet) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Returns:
            `Tuple[torch.Tensor, Optional[torch.Tensor]]`: `(prompt_embeds, negative_prompt_embeds)`. Both are repeated
            `num_images_per_prompt` times. `negative_prompt_embeds` is returned exactly as passed in (possibly `None`)
            when `do_classifier_free_guidance` is falsy.

        Raises:
            TypeError: If `negative_prompt` and `prompt` are of different types.
            ValueError: If a list `negative_prompt` does not have `batch_size` entries.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # The batch size comes from `prompt` when given, otherwise from the pre-computed embeddings.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Tokenize a second time without truncation, only to detect and report what was cut off.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            # An attention mask is only passed when the text encoder config asks for one.
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        # Target dtype preference: text encoder, then unet, then whatever the embeddings already use.
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad the unconditional tokens to the sequence length of the positive embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) if output_hidden_states: image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True ).hidden_states[-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) return image_enc_hidden_states, uncond_image_enc_hidden_states else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): image_embeds = [] if do_classifier_free_guidance: negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): raise ValueError( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
) for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): output_hidden_state = not isinstance(image_proj_layer, ImageProjection) single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, image, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ip_adapter_image=None, ip_adapter_image_embeds=None, controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, callback_on_step_end_tensor_inputs=None, ): if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.controlnet, torch._dynamo.eval_frame.OptimizedModule ) if ( isinstance(self.controlnet, ControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel) ): self.check_image(image, prompt, prompt_embeds) elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel) ): if not isinstance(image, list): raise TypeError("For multiple controlnets: `image` must be type `list`") # When `image` is a nested list: # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): transposed_image = [list(t) for t in zip(*image)] if len(transposed_image) != len(self.controlnet.nets): raise ValueError( f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." 
) for image_ in transposed_image: self.check_image(image_, prompt, prompt_embeds) elif len(image) != len(self.controlnet.nets): raise ValueError( f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." ) else: for image_ in image: self.check_image(image_, prompt, prompt_embeds) else: assert False # Check `controlnet_conditioning_scale` if ( isinstance(self.controlnet, ControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel) ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel) ): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): raise ValueError( "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " "The conditioning scale must be fixed across the batch." ) elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( self.controlnet.nets ): raise ValueError( "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" " the same length as the number of controlnets" ) else: assert False if not isinstance(control_guidance_start, (tuple, list)): control_guidance_start = [control_guidance_start] if not isinstance(control_guidance_end, (tuple, list)): control_guidance_end = [control_guidance_end] if len(control_guidance_start) != len(control_guidance_end): raise ValueError( f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." 
) if isinstance(self.controlnet, MultiControlNetModel): if len(control_guidance_start) != len(self.controlnet.nets): raise ValueError( f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." ) for start, end in zip(control_guidance_start, control_guidance_end): if start >= end: raise ValueError( f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." ) if start < 0.0: raise ValueError(f"control guidance start: {start} can't be smaller than 0.") if end > 1.0: raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." ) if ip_adapter_image_embeds is not None: if not isinstance(ip_adapter_image_embeds, list): raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) image_is_np = isinstance(image, np.ndarray) image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) if ( not image_is_pil and not image_is_tensor and not image_is_np and not image_is_pil_list and not image_is_tensor_list and not image_is_np_list ): raise TypeError( f"image must be passed and be one 
of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" ) if image_is_pil: image_batch_size = 1 else: image_batch_size = len(image) if prompt is not None and isinstance(prompt, str): prompt_batch_size = 1 elif prompt is not None and isinstance(prompt, list): prompt_batch_size = len(prompt) elif prompt_embeds is not None: prompt_batch_size = prompt_embeds.shape[0] if image_batch_size != 1 and image_batch_size != prompt_batch_size: raise ValueError( f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" ) def prepare_image( self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance=False, guess_mode=False, ): image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size else: # image batch size is the same as prompt batch size repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) if do_classifier_free_guidance and not guess_mode: image = torch.cat([image] * 2) return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding( self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: w (`torch.Tensor`): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type of the generated embeddings. Returns: `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """ assert len(w.shape) == 1 w = w * 1000.0 half_dim = embedding_dim // 2 emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) emb = w.to(dtype)[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb @property def guidance_scale(self): return self._guidance_scale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: PipelineImageInput = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    guess_mode: bool = False,
    control_guidance_start: Union[float, List[float]] = 0.0,
    control_guidance_end: Union[float, List[float]] = 1.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
            The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
            specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
            as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
            width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
            images must be passed as a list such that each element of the list can be correctly batched for input
            to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single
            ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple
            ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            Deprecated; read from `**kwargs`. A function that calls every `callback_steps` steps during inference.
            The function is called with the following arguments: `callback(step: int, timestep: int, latents:
            torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            Deprecated; read from `**kwargs`. The frequency at which the `callback` function is called. If not
            specified, the callback is called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
            the corresponding scale as a list.
        guess_mode (`bool`, *optional*, defaults to `False`):
            The ControlNet encoder tries to recognize the content of the input image even if you remove all
            prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
        control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
            The percentage of total steps at which the ControlNet starts applying.
        control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
            The percentage of total steps at which the ControlNet stops applying.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """

    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )

    # `callback_steps` is documented as defaulting to 1 but is popped with a `None` default; without this fallback
    # `i % callback_steps` in the denoising loop raises a TypeError when only `callback` is supplied.
    if callback is not None and callback_steps is None:
        callback_steps = 1

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

    # align format for control guidance
    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
        control_guidance_start, control_guidance_end = (
            mult * [control_guidance_start],
            mult * [control_guidance_end],
        )

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        image,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        controlnet_conditioning_scale,
        control_guidance_start,
        control_guidance_end,
        callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
        controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

    global_pool_conditions = (
        controlnet.config.global_pool_conditions
        if isinstance(controlnet, ControlNetModel)
        else controlnet.nets[0].config.global_pool_conditions
    )
    guess_mode = guess_mode or global_pool_conditions

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 4. Prepare image
    if isinstance(controlnet, ControlNetModel):
        image = self.prepare_image(
            image=image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=controlnet.dtype,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            guess_mode=guess_mode,
        )
        height, width = image.shape[-2:]
    elif isinstance(controlnet, MultiControlNetModel):
        images = []

        # Nested lists as ControlNet condition
        if isinstance(image[0], list):
            # Transpose the nested image list
            image = [list(t) for t in zip(*image)]

        for image_ in image:
            image_ = self.prepare_image(
                image=image_,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )

            images.append(image_)

        image = images
        height, width = image[0].shape[-2:]
    else:
        # Raised explicitly (instead of `assert False`) so the guard survives `python -O`; the exception type is
        # kept as AssertionError for backward compatibility.
        raise AssertionError(f"Unsupported controlnet type: {type(controlnet)}")

    # 5. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )
    self._num_timesteps = len(timesteps)

    # 6. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6.5 Optionally get Guidance Scale Embedding
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7.1 Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    # 7.2 Create tensor stating which controlnets to keep
    controlnet_keep = []
    for i in range(len(timesteps)):
        # 1.0 while step `i` lies inside [start, end] of the schedule, 0.0 otherwise (one entry per ControlNet)
        keeps = [
            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    is_unet_compiled = is_compiled_module(self.unet)
    is_controlnet_compiled = is_compiled_module(self.controlnet)
    is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Relevant thread:
            # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
            if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
                torch._inductor.cudagraph_mark_step_begin()
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # controlnet(s) inference
            if guess_mode and self.do_classifier_free_guidance:
                # Infer ControlNet only for the conditional batch.
                control_model_input = latents
                control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
            else:
                control_model_input = latent_model_input
                controlnet_prompt_embeds = prompt_embeds

            if isinstance(controlnet_keep[i], list):
                cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
            else:
                controlnet_cond_scale = controlnet_conditioning_scale
                if isinstance(controlnet_cond_scale, list):
                    controlnet_cond_scale = controlnet_cond_scale[0]
                cond_scale = controlnet_cond_scale * controlnet_keep[i]

            down_block_res_samples, mid_block_res_sample = self.controlnet(
                control_model_input,
                t,
                encoder_hidden_states=controlnet_prompt_embeds,
                controlnet_cond=image,
                conditioning_scale=cond_scale,
                guess_mode=guess_mode,
                return_dict=False,
            )

            if guess_mode and self.do_classifier_free_guidance:
                # Inferred ControlNet only for the conditional batch.
                # To apply the output of ControlNet to both the unconditional and conditional batches,
                # add 0 to the unconditional batch to keep it unchanged.
                down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=self.cross_attention_kwargs,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    # requested tensors are looked up by name among this function's local variables
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    # If we do sequential model offloading, let's offload unet and controlnet
    # manually for max memory savings
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.unet.to("cpu")
        self.controlnet.to("cpu")
        torch.cuda.empty_cache()

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
            0
        ]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
from typing import List, Optional, Union import PIL.Image import torch from transformers import CLIPTokenizer from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import PNDMScheduler from ...utils import ( logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..blip_diffusion.blip_image_processing import BlipImageProcessor from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers.pipelines import BlipDiffusionControlNetPipeline >>> from diffusers.utils import load_image >>> from controlnet_aux import CannyDetector >>> import torch >>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained( ... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16 ... ).to("cuda") >>> style_subject = "flower" >>> tgt_subject = "teapot" >>> text_prompt = "on a marble table" >>> cldm_cond_image = load_image( ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg" ... ).resize((512, 512)) >>> canny = CannyDetector() >>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil") >>> style_image = load_image( ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg" ... ) >>> guidance_scale = 7.5 >>> num_inference_steps = 50 >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate" >>> output = blip_diffusion_pipe( ... text_prompt, ... style_image, ... cldm_cond_image, ... style_subject, ... tgt_subject, ... guidance_scale=guidance_scale, ... 
class BlipDiffusionControlNetPipeline(DiffusionPipeline):
    """
    Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        tokenizer ([`CLIPTokenizer`]):
            Tokenizer for the text encoder
        text_encoder ([`ContextCLIPTextModel`]):
            Text encoder to encode the text prompt
        vae ([`AutoencoderKL`]):
            VAE model to map the latents to the image
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        scheduler ([`PNDMScheduler`]):
             A scheduler to be used in combination with `unet` to generate image latents.
        qformer ([`Blip2QFormerModel`]):
            QFormer model to get multi-modal embeddings from the text and image.
        controlnet ([`ControlNetModel`]):
            ControlNet model to get the conditioning image embedding.
        image_processor ([`BlipImageProcessor`]):
            Image Processor to preprocess and postprocess the image.
        ctx_begin_pos (int, `optional`, defaults to 2):
            Position of the context token in the text encoder.
        mean (`List[float]`, *optional*):
            Stored in the pipeline config and passed as `image_mean` when preprocessing the reference image.
        std (`List[float]`, *optional*):
            Stored in the pipeline config and passed as `image_std` when preprocessing the reference image.
    """

    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: ContextCLIPTextModel,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: PNDMScheduler,
        qformer: Blip2QFormerModel,
        controlnet: ControlNetModel,
        image_processor: BlipImageProcessor,
        ctx_begin_pos: int = 2,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            qformer=qformer,
            controlnet=controlnet,
            image_processor=image_processor,
        )
        self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)

    def get_query_embeddings(self, input_image, src_subject):
        """Run the QFormer on the reference image and source subject text and return its (non-dict) output."""
        return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)

    # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
    def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
        """
        Prefix each prompt with its target subject and repeat it to amplify its weight.

        Each output string is `"a {tgt_subject} {prompt}"` joined with `", "` exactly
        `int(prompt_strength * prompt_reps)` times. Prompts and subjects are paired with `zip`, so extra entries
        of the longer list are dropped.
        """
        repeats = int(prompt_strength * prompt_reps)
        rv = []
        for prompt, tgt_subject in zip(prompts, tgt_subjects):
            prompt = f"a {tgt_subject} {prompt.strip()}"
            # a trick to amplify the prompt
            rv.append(", ".join([prompt] * repeats))

        return rv

    # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
        """
        Return initial latents of shape `(batch_size, num_channels, height, width)` scaled for the scheduler.

        Fresh noise is drawn with `randn_tensor` when `latents` is `None`; otherwise the given latents are moved to
        `device` / `dtype`. In both cases the result is multiplied by `self.scheduler.init_noise_sigma`.

        Raises:
            ValueError: If `generator` is a list whose length differs from `batch_size`.
        """
        shape = (batch_size, num_channels, height, width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def encode_prompt(self, query_embeds, prompt, device=None):
        """
        Encode `prompt` with the text encoder, using `query_embeds` as context embeddings.

        The prompt is tokenized (padded and truncated) to the text encoder's `max_position_embeddings` minus the
        QFormer's `num_query_tokens`, and the context begin position comes from `self.config.ctx_begin_pos`.

        Returns:
            The first element of the text encoder output.
        """
        device = device or self._execution_device

        # embeddings for prompt, with query_embeds as context
        max_len = self.text_encoder.text_model.config.max_position_embeddings
        max_len -= self.qformer.config.num_query_tokens

        tokenized_prompt = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=max_len,
            return_tensors="pt",
        ).to(device)

        batch_size = query_embeds.shape[0]
        ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size

        text_embeddings = self.text_encoder(
            input_ids=tokenized_prompt.input_ids,
            ctx_embeddings=query_embeds,
            ctx_begin_pos=ctx_begin_pos,
        )[0]

        return text_embeddings

    # Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
    def prepare_control_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
    ):
        """
        Preprocess the ControlNet conditioning image and expand it to the requested batch.

        The image is resized and rescaled (no center crop, no normalization) by `self.image_processor`, repeated
        `batch_size` times when it is a single image (otherwise `num_images_per_prompt` times per image), moved to
        `device` / `dtype`, and concatenated with itself when `do_classifier_free_guidance` is `True`.
        """
        image = self.image_processor.preprocess(
            image,
            size={"width": width, "height": height},
            do_rescale=True,
            do_center_crop=False,
            do_normalize=False,
            return_tensors="pt",
        )["pixel_values"].to(device)
        image_batch_size = image.shape[0]

        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt

        image = image.repeat_interleave(repeat_by, dim=0)

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance:
            image = torch.cat([image] * 2)

        return image

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: List[str],
        reference_image: PIL.Image.Image,
        condtioning_image: PIL.Image.Image,
        source_subject_category: List[str],
        target_subject_category: List[str],
        latents: Optional[torch.Tensor] = None,
        guidance_scale: float = 7.5,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        neg_prompt: Optional[str] = "",
        prompt_strength: float = 1.0,
        prompt_reps: int = 20,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`List[str]`):
                The prompt or prompts to guide the image generation. A single `str` is also accepted and wrapped
                into a one-element list.
            reference_image (`PIL.Image.Image`):
                The reference image to condition the generation on.
            condtioning_image (`PIL.Image.Image`):
                The conditioning canny edge image to condition the generation on.
            source_subject_category (`List[str]`):
                The source subject category. A single `str` is also accepted.
            target_subject_category (`List[str]`):
                The target subject category. A single `str` is also accepted.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by random sampling.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            height (`int`, *optional*, defaults to 512):
                The height of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            neg_prompt (`str`, *optional*, defaults to ""):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is not greater than `1`).
            prompt_strength (`float`, *optional*, defaults to 1.0):
                The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
                to amplify the prompt.
            prompt_reps (`int`, *optional*, defaults to 20):
                The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Output format forwarded to `self.image_processor.postprocess`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        device = self._execution_device

        reference_image = self.image_processor.preprocess(
            reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
        )["pixel_values"]
        reference_image = reference_image.to(device)

        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(source_subject_category, str):
            source_subject_category = [source_subject_category]
        if isinstance(target_subject_category, str):
            target_subject_category = [target_subject_category]

        batch_size = len(prompt)

        prompt = self._build_prompt(
            prompts=prompt,
            tgt_subjects=target_subject_category,
            prompt_strength=prompt_strength,
            prompt_reps=prompt_reps,
        )
        query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
        text_embeddings = self.encode_prompt(query_embeds, prompt, device)

        # unconditional embedding; the guidance flag is computed once and reused by the denoising loop
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            max_length = self.text_encoder.text_model.config.max_position_embeddings

            # `truncation=True` mirrors `encode_prompt`: an over-long negative prompt is cut to `max_length` instead
            # of producing a sequence longer than the text encoder's position limit.
            uncond_input = self.tokenizer(
                [neg_prompt] * batch_size,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(
                input_ids=uncond_input.input_ids.to(device),
                ctx_embeddings=None,
            )[0]
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # latents live at the UNet's downscaled resolution
        scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
        latents = self.prepare_latents(
            batch_size=batch_size,
            num_channels=self.unet.config.in_channels,
            height=height // scale_down_factor,
            width=width // scale_down_factor,
            generator=generator,
            latents=latents,
            dtype=self.unet.dtype,
            device=device,
        )
        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        cond_image = self.prepare_control_image(
            image=condtioning_image,
            width=width,
            height=height,
            batch_size=batch_size,
            num_images_per_prompt=1,
            device=device,
            dtype=self.controlnet.dtype,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            down_block_res_samples, mid_block_res_sample = self.controlnet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                controlnet_cond=cond_image,
                return_dict=False,
            )

            noise_pred = self.unet(
                latent_model_input,
                timestep=t,
                encoder_hidden_states=text_embeddings,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
            )["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
            )["prev_sample"]

        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from .multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install opencv-python transformers accelerate >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> import cv2 >>> from PIL import Image >>> # download an image >>> image = load_image( ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" ... 
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from an encoder output object.

    Args:
        encoder_output: Object exposing either a `latent_dist` attribute or a
            `latents` attribute.
        generator: Forwarded to `latent_dist.sample` when sampling.
        sample_mode: `"sample"` draws from `latent_dist`; `"argmax"` returns
            `latent_dist.mode()`.

    Returns:
        The sampled latents, the distribution mode, or the stored `latents`.

    Raises:
        AttributeError: If none of the access paths above applies.
    """
    has_distribution = hasattr(encoder_output, "latent_dist")

    if has_distribution and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)

    if has_distribution and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()

    # Fall back to pre-computed latents, whatever `sample_mode` was requested.
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): image = np.concatenate([i[None, :] for i in image], axis=0) image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 return image class StableDiffusionControlNetImg2ImgPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin, ): r""" Pipeline for image-to-image generation using Stable Diffusion with ControlNet guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): Provides additional conditioning to the `unet` during the denoising process. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    image_encoder: CLIPVisionModelWithProjection = None,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and build the image processors.

    A `list` or `tuple` of ControlNets is wrapped in a `MultiControlNetModel`
    before registration.

    Raises:
        ValueError: If a `safety_checker` is passed without a
            `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first segment must be an f-string, otherwise `{self.__class__}`
        # is emitted literally instead of the pipeline class.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    # Control images are converted to RGB but not normalized.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when `do_classifier_free_guidance` is set and
            `negative_prompt_embeds` is not given. A single `str` is applied to every entry of the batch.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        A tuple `(prompt_embeds, negative_prompt_embeds)`. Each tensor is repeated `num_images_per_prompt` times
        per batch entry. `negative_prompt_embeds` is returned as passed in (possibly `None`) when
        `do_classifier_free_guidance` is false.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        ValueError: If a list `negative_prompt` does not match the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            # Repeat the single negative prompt for every batch entry. With
            # `prompt=None` and batched `prompt_embeds`, a one-element list
            # would make the `.view(batch_size * ...)` below fail.
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Build one image-embedding tensor per IP Adapter.

    When `ip_adapter_image_embeds` is `None`, each entry of `ip_adapter_image`
    is encoded with `self.encode_image`. Otherwise the given tensors are used
    and, under classifier-free guidance, each one is split in two along dim 0
    into a (negative, positive) pair.

    Every embedding is repeated `num_images_per_prompt` times along dim 0.
    Under classifier-free guidance the repeated negative embeddings are
    concatenated in front of the positive ones.

    Returns:
        A list of tensors moved to `device`, one per input embedding.

    Raises:
        ValueError: If the number of images differs from the number of
            image projection layers on `self.unet.encoder_hid_proj`.
    """
    positives = []
    negatives = []  # only filled and read under classifier-free guidance

    if ip_adapter_image_embeds is not None:
        for provided in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_part, provided = provided.chunk(2)
                negatives.append(negative_part)
            positives.append(provided)
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        adapter_pairs = zip(ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers)
        for adapter_image, projection_layer in adapter_pairs:
            # Hidden states are requested for every projection layer that is
            # not a plain `ImageProjection`.
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            encoded, encoded_negative = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positives.append(encoded[None, :])
            if do_classifier_free_guidance:
                negatives.append(encoded_negative[None, :])

    prepared = []
    for index, positive in enumerate(positives):
        tiled = torch.cat([positive] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negatives[index]] * num_images_per_prompt, dim=0)
            tiled = torch.cat([tiled_negative, tiled], dim=0)
        prepared.append(tiled.to(device=device))

    return prepared
def decode_latents(self, latents):
    """Decode latents with the VAE into a NumPy image batch (deprecated).

    The latents are divided by `self.vae.config.scaling_factor` and decoded.
    The result is mapped through `x / 2 + 0.5`, clamped to `[0, 1]`, and
    returned as a float32 NumPy array with the channel axis moved last.
    """
    deprecate(
        "decode_latents",
        "1.0.0",
        "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead",
        standard_warn=False,
    )

    scaled_latents = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(scaled_latents, return_dict=False)[0]
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return decoded.cpu().permute(0, 2, 3, 1).float().numpy()
def check_inputs(
    self,
    prompt,
    image,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the arguments of a pipeline call before any work is done.

    Checks, in order: `callback_steps`, callback tensor names, the
    prompt / embedding combinations, the control `image` (delegated to
    `self.check_image`), `controlnet_conditioning_scale`, the control
    guidance window, and the IP-Adapter inputs.

    `control_guidance_start` / `control_guidance_end` may be numbers or
    sequences. Numbers are broadcast before validation: to the length of the
    other argument if that one is a sequence, otherwise to one entry per
    ControlNet.

    Raises:
        ValueError: For inconsistent or out-of-range argument values.
        TypeError: For a control `image` or conditioning scale of the wrong
            type.
        AssertionError: If `self.controlnet` is neither a `ControlNetModel`
            nor a `MultiControlNetModel` (optionally compiled).
    """
    if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # A compiled ControlNet is a wrapper; the real model is `_orig_mod`.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    is_single_controlnet = isinstance(self.controlnet, ControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel)
    )
    is_multi_controlnet = isinstance(self.controlnet, MultiControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    )

    # Check `image`
    if is_single_controlnet:
        self.check_image(image, prompt, prompt_embeds)
    elif is_multi_controlnet:
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")

        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        elif any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        elif len(image) != len(self.controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)
    else:
        # Raised explicitly: a bare `assert False` disappears under `python -O`.
        raise AssertionError(f"Unsupported `controlnet` type: {type(self.controlnet)}")

    # Check `controlnet_conditioning_scale`
    if is_single_controlnet:
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    else:
        # Only the multi-ControlNet case reaches this branch; anything else
        # was rejected by the `image` check above.
        if isinstance(controlnet_conditioning_scale, list):
            if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
            # Nested under the list check so it actually runs; as an `elif`
            # repeating the enclosing `if` condition it could never be reached.
            if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
                raise ValueError(
                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of controlnets"
                )

    # Broadcast plain numbers so that this method's own defaults (0.0 / 1.0)
    # are accepted instead of failing on `len()`.
    start_is_number = isinstance(control_guidance_start, (int, float))
    end_is_number = isinstance(control_guidance_end, (int, float))
    if start_is_number and end_is_number:
        net_count = len(self.controlnet.nets) if isinstance(self.controlnet, MultiControlNetModel) else 1
        control_guidance_start = [control_guidance_start] * net_count
        control_guidance_end = [control_guidance_end] * net_count
    elif start_is_number:
        control_guidance_start = [control_guidance_start] * len(control_guidance_end)
    elif end_is_number:
        control_guidance_end = [control_guidance_end] * len(control_guidance_start)

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
            )
def prepare_control_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess the control image and expand it to the working batch size.

    The image goes through `self.control_image_processor.preprocess` and is
    cast to float32. A single image is repeated `batch_size` times; a batch
    of images is repeated `num_images_per_prompt` times per entry. The result
    is moved to `device` / `dtype` and, under classifier-free guidance
    outside guess mode, duplicated along the batch axis.
    """
    control = self.control_image_processor.preprocess(image, height=height, width=width)
    control = control.to(dtype=torch.float32)

    # A batch with more than one image is taken to line up with the prompts.
    repeats = batch_size if control.shape[0] == 1 else num_images_per_prompt
    control = control.repeat_interleave(repeats, dim=0).to(device=device, dtype=dtype)

    if guess_mode or not do_classifier_free_guidance:
        return control
    return torch.cat([control] * 2)
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
    """Turn the initial image into noised latents for the first denoising step.

    An `image` with 4 channels (dim 1) is used as latents directly. Anything
    else is encoded with `self.vae` and multiplied by
    `self.vae.config.scaling_factor`. The latents are then expanded to
    `batch_size * num_images_per_prompt` entries and noised with
    `self.scheduler.add_noise` at `timestep`.

    Raises:
        ValueError: If `image` has an unsupported type, if a list of
            generators does not match the effective batch size, or if the
            image batch cannot be evenly duplicated to the effective batch
            size.
    """
    accepted_types = (torch.Tensor, PIL.Image.Image, list)
    if not isinstance(image, accepted_types):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)
    effective_batch = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        # Already in latent space: skip the VAE entirely.
        init_latents = image
    else:
        per_sample_generators = isinstance(generator, list)
        if per_sample_generators and len(generator) != effective_batch:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {effective_batch}. Make sure the batch size matches the length of the generators."
            )

        if per_sample_generators:
            image_count = image.shape[0]
            if image_count < effective_batch:
                if effective_batch % image_count != 0:
                    raise ValueError(
                        f"Cannot duplicate `image` of batch size {image_count} to effective batch_size {effective_batch} "
                    )
                image = torch.cat([image] * (effective_batch // image_count), dim=0)

            # Encode one sample at a time so each gets its own generator.
            encoded = []
            for index in range(effective_batch):
                posterior = self.vae.encode(image[index : index + 1])
                encoded.append(retrieve_latents(posterior, generator=generator[index]))
            init_latents = torch.cat(encoded, dim=0)
        else:
            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        init_latents = self.vae.config.scaling_factor * init_latents

    latent_count = init_latents.shape[0]
    if effective_batch > latent_count:
        if effective_batch % latent_count != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {latent_count} to {effective_batch} text prompts."
            )

        # expand init_latents for batch_size
        deprecation_message = (
            f"You have passed {effective_batch} text prompts (`prompt`), but only {latent_count} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        copies = effective_batch // latent_count
        init_latents = torch.cat([init_latents] * copies, dim=0)
    else:
        init_latents = torch.cat([init_latents], dim=0)

    noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)

    # get latents
    return self.scheduler.add_noise(init_latents, noise, timestep)
@property
def do_classifier_free_guidance(self):
    """Return `True` when the stored guidance scale is strictly greater than 1."""
    return self._guidance_scale > 1


@property
def cross_attention_kwargs(self):
    """Return the cross-attention kwargs stored by the most recent `__call__`."""
    return self._cross_attention_kwargs


@property
def num_timesteps(self):
    """Return the number of timesteps recorded by the most recent `__call__`."""
    return self._num_timesteps


@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: PipelineImageInput = None,
    control_image: PipelineImageInput = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    strength: float = 0.8,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
    guess_mode: bool = False,
    control_guidance_start: Union[float, List[float]] = 0.0,
    control_guidance_end: Union[float, List[float]] = 1.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    # NOTE(review): mutable default; this method only reads or rebinds it, never mutates it in place.
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`,
                `List[np.ndarray]`, `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or
                `List[List[PIL.Image.Image]]`):
            The initial image to be used as the starting point for the image generation process. Can also accept
            image latents as `image`, and if passing latents directly they are not encoded again.
        control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`,
                `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.Tensor]]`,
                `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
            The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
            specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
            as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
            width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
            images must be passed as a list such that each element of the list can be correctly batched for input
            to a single ControlNet.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        strength (`float`, *optional*, defaults to 0.8):
            Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
            starting point and more noise is added the higher the `strength`. The number of denoising steps
            depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the
            denoising process runs for the full number of iterations specified in `num_inference_steps`. A value
            of 1 essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents to be used as inputs for image generation. Can be used to tweak the same
            generation with different prompts. If not provided, latents are produced from `image` by
            `prepare_latents` using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. Passing
            `"latent"` skips VAE decoding and the safety checker and returns the final latents.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
            the corresponding scale as a list.
        guess_mode (`bool`, *optional*, defaults to `False`):
            The ControlNet encoder tries to recognize the content of the input image even if you remove all
            prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
        control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
            The percentage of total steps at which the ControlNet starts applying.
        control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
            The percentage of total steps at which the ControlNet stops applying.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during inference, with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """

    # Deprecated per-step callback arguments are still accepted through **kwargs.
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )

    # Callback objects carry their own list of requested tensors; it overrides the argument.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # Use the wrapped module for the isinstance checks below when the controlnet is compiled.
    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

    # align format for control guidance
    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
        control_guidance_start, control_guidance_end = (
            mult * [control_guidance_start],
            mult * [control_guidance_end],
        )

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        control_image,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        controlnet_conditioning_scale,
        control_guidance_start,
        control_guidance_end,
        callback_on_step_end_tensor_inputs,
    )

    # Stored on the instance so the read-only properties reflect this call.
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # A single float scale is broadcast to one entry per ControlNet.
    if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
        controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

    global_pool_conditions = (
        controlnet.config.global_pool_conditions
        if isinstance(controlnet, ControlNetModel)
        else controlnet.nets[0].config.global_pool_conditions
    )
    # `global_pool_conditions` in the ControlNet config forces guess mode on.
    guess_mode = guess_mode or global_pool_conditions

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 4. Prepare image
    image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

    # 5. Prepare controlnet_conditioning_image
    if isinstance(controlnet, ControlNetModel):
        control_image = self.prepare_control_image(
            image=control_image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=controlnet.dtype,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            guess_mode=guess_mode,
        )
    elif isinstance(controlnet, MultiControlNetModel):
        control_images = []

        for control_image_ in control_image:
            control_image_ = self.prepare_control_image(
                image=control_image_,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )

            control_images.append(control_image_)

        control_image = control_images
    else:
        # NOTE(review): `assert` is removed under `python -O`; an explicit raise would be safer here.
        assert False

    # 6. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    # The first remaining timestep, repeated once per generated sample, is the level the init latents are noised to.
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
    self._num_timesteps = len(timesteps)

    # 7. Prepare latent variables
    if latents is None:
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
        )

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 8.1 Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    # 8.2 Create tensor stating which controlnets to keep
    # Each entry is 1.0 when step `i` lies inside [start, end] of the guidance window, else 0.0.
    controlnet_keep = []
    for i in range(len(timesteps)):
        keeps = [
            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

    # 9. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # controlnet(s) inference
            if guess_mode and self.do_classifier_free_guidance:
                # Infer ControlNet only for the conditional batch.
                control_model_input = latents
                control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
            else:
                control_model_input = latent_model_input
                controlnet_prompt_embeds = prompt_embeds

            # Multiply the user scale by the 0/1 keep mask for this step.
            if isinstance(controlnet_keep[i], list):
                cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
            else:
                controlnet_cond_scale = controlnet_conditioning_scale
                if isinstance(controlnet_cond_scale, list):
                    controlnet_cond_scale = controlnet_cond_scale[0]
                cond_scale = controlnet_cond_scale * controlnet_keep[i]

            down_block_res_samples, mid_block_res_sample = self.controlnet(
                control_model_input,
                t,
                encoder_hidden_states=controlnet_prompt_embeds,
                controlnet_cond=control_image,
                conditioning_scale=cond_scale,
                guess_mode=guess_mode,
                return_dict=False,
            )

            if guess_mode and self.do_classifier_free_guidance:
                # Inferred ControlNet only for the conditional batch.
                # To apply the output of ControlNet to both the unconditional and conditional batches,
                # add 0 to the unconditional batch to keep it unchanged.
                down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=self.cross_attention_kwargs,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                # Tensors are looked up by local variable name, so the local names in this method
                # are part of the callback contract and must not be renamed.
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # The callback may replace any of these tensors for the following steps.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                # NOTE(review): `i % callback_steps` needs a non-None `callback_steps` whenever the legacy
                # `callback` is given; presumably `check_inputs` enforces that -- confirm.
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    # If we do sequential model offloading, let's offload unet and controlnet
    # manually for max memory savings
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.unet.to("cpu")
        self.controlnet.to("cpu")
        torch.cuda.empty_cache()

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
            0
        ]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        # Latent output: no decoding and no safety check.
        image = latents
        has_nsfw_concept = None

    # Images flagged by the safety checker are not denormalized.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)


================================================
FILE: src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from .multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install transformers accelerate >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> init_image = load_image( ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" ... ) >>> init_image = init_image.resize((512, 512)) >>> generator = torch.Generator(device="cpu").manual_seed(1) >>> mask_image = load_image( ... 
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" ... ) >>> mask_image = mask_image.resize((512, 512)) >>> def make_canny_condition(image): ... image = np.array(image) ... image = cv2.Canny(image, 100, 200) ... image = image[:, :, None] ... image = np.concatenate([image, image, image], axis=2) ... image = Image.fromarray(image) ... return image >>> control_image = make_canny_condition(init_image) >>> controlnet = ControlNetModel.from_pretrained( ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 ... ) >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 ... ) >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) >>> pipe.enable_model_cpu_offload() >>> # generate image >>> image = pipe( ... "a handsome man with ray-ban sunglasses", ... num_inference_steps=20, ... generator=generator, ... eta=1.0, ... image=init_image, ... mask_image=mask_image, ... control_image=control_image, ... 
).images[0] ``` """ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" ): if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: raise AttributeError("Could not access latents of provided encoder_output") class StableDiffusionControlNetInpaintPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin, ): r""" Pipeline for image inpainting using Stable Diffusion with ControlNet guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting ([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as default text-to-image Stable Diffusion checkpoints ([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). 
Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): Provides additional conditioning to the `unet` during the denoising process. If you set multiple ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
""" model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) if isinstance(controlnet, (list, tuple)): controlnet = MultiControlNetModel(controlnet) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True ) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Turn a text prompt (or pre-computed embeddings) into text-encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt to encode. May be `None` when `prompt_embeds` is supplied.
        device (`torch.device`):
            Device the returned embeddings are moved to.
        num_images_per_prompt (`int`):
            How many times each prompt's embedding is duplicated.
        do_classifier_free_guidance (`bool`):
            Whether unconditional (negative) embeddings must be produced as well.
        negative_prompt (`str` or `List[str]`, *optional*):
            Text for the unconditional branch. Only read when guidance is on and
            `negative_prompt_embeds` is not given; defaults to empty strings.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed prompt embeddings; when given, `prompt` is not tokenized.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed unconditional embeddings.
        lora_scale (`float`, *optional*):
            Scale applied to the text encoder's LoRA layers while encoding.
        clip_skip (`int`, *optional*):
            Number of final encoder layers to skip; the selected hidden state is
            then passed through the text model's final layer norm.

    Returns:
        `(prompt_embeds, negative_prompt_embeds)`. The second element is returned
        exactly as passed in (possibly `None`) when guidance is disabled.
    """

    def _attention_mask_for(tokenized):
        # Forward a mask only when the text encoder's config asks for one.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            return tokenized.attention_mask.to(device)
        return None

    # Store the LoRA scale so the (monkey patched) text encoder LoRA code can read it,
    # then apply it through whichever backend is active.
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale
        if USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder, lora_scale)
        else:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)

    # Batch size comes from the prompt when there is one, else from the embeddings.
    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: expand multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Warn about whatever part of the prompt fell off the end of the context window.
        was_truncated = untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        )
        if was_truncated:
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        attention_mask = _attention_mask_for(text_inputs)

        if clip_skip is None:
            encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = encoder_output[0]
        else:
            encoder_output = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # The last output entry is the tuple of per-layer hidden states; pick the
            # layer `clip_skip` steps before the end.
            skipped_hidden_state = encoder_output[-1][-(clip_skip + 1)]
            # Apply the final LayerNorm ourselves, since the usual last hidden state
            # has already gone through it.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(skipped_hidden_state)

    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # Build unconditional embeddings for classifier free guidance when none were supplied.
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: expand multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the unconditional input to the same sequence length as the prompt embeddings.
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=prompt_embeds.shape[1],
            truncation=True,
            return_tensors="pt",
        )

        attention_mask = _attention_mask_for(uncond_input)

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]
        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None and isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build the per-adapter list of IP-Adapter image embeddings.

    Either `ip_adapter_image` (one image per adapter, encoded here through
    `self.encode_image`) or `ip_adapter_image_embeds` (pre-computed tensors) is used.
    With classifier free guidance, pre-computed tensors are split in two along dim 0
    into (negative, positive) halves. Each embedding is tiled `num_images_per_prompt`
    times along dim 0; with guidance the negative tiles are placed before the
    positive ones. Results are moved to `device`.

    Raises:
        ValueError: if the number of images differs from the number of image
            projection layers on the UNet.
    """
    positive_embeds = []
    # Only filled (and only read) when classifier free guidance is on.
    negative_embeds = []

    if ip_adapter_image_embeds is not None:
        for packed_embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_half, packed_embeds = packed_embeds.chunk(2)
                negative_embeds.append(negative_half)
            positive_embeds.append(packed_embeds)
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        projection_layers = self.unet.encoder_hid_proj.image_projection_layers
        if len(ip_adapter_image) != len(projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        for adapter_image, projection_layer in zip(ip_adapter_image, projection_layers):
            # Anything other than a plain ImageProjection is fed hidden states.
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            encoded, encoded_uncond = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positive_embeds.append(encoded[None, :])
            if do_classifier_free_guidance:
                negative_embeds.append(encoded_uncond[None, :])

    prepared = []
    for index, embeds in enumerate(positive_embeds):
        tiled = torch.cat([embeds] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negative_embeds[index]] * num_images_per_prompt, dim=0)
            tiled = torch.cat([tiled_negative, tiled], dim=0)
        prepared.append(tiled.to(device=device))

    return prepared
def decode_latents(self, latents):
    """
    Deprecated: decode VAE latents into a NumPy image batch.

    Emits a deprecation notice, divides the latents by the VAE scaling factor,
    decodes them, maps the result from `[-1, 1]` to `[0, 1]` (clamped) and returns
    it as a float32 NumPy array with the channel axis moved last.
    """
    deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
    deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

    unscaled_latents = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(unscaled_latents, return_dict=False)[0]
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return pixels.cpu().permute(0, 2, 3, 1).float().numpy()
def check_inputs(
    self,
    prompt,
    image,
    mask_image,
    height,
    width,
    callback_steps,
    output_type,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
    callback_on_step_end_tensor_inputs=None,
    padding_mask_crop=None,
):
    """
    Validate the arguments of a pipeline call and raise on the first problem found.

    Checks, in order: image size divisibility, callback arguments, prompt /
    embedding exclusivity and shapes, mask-crop requirements, the control image(s),
    `controlnet_conditioning_scale`, the control guidance windows and the
    IP-Adapter inputs.

    `control_guidance_start` and `control_guidance_end` must be sized sequences
    (`len()` is taken of both); `__call__` converts scalars to lists before
    calling this method.

    Raises:
        ValueError: for inconsistent or out-of-range values.
        TypeError: for arguments of the wrong type, including a `self.controlnet`
            that is neither a `ControlNetModel` nor a `MultiControlNetModel`.
    """
    if (height is not None and height % 8 != 0) or (width is not None and width % 8 != 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Exactly one of `prompt` / `prompt_embeds` must be given.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # Mask cropping requires PIL inputs and PIL output.
    if padding_mask_crop is not None:
        if not isinstance(image, PIL.Image.Image):
            raise ValueError(
                f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
            )
        if not isinstance(mask_image, PIL.Image.Image):
            raise ValueError(
                f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                f" {type(mask_image)}."
            )
        if output_type != "pil":
            raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # A torch.compile'd controlnet is wrapped; look through the wrapper at `_orig_mod`.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    is_single_controlnet = isinstance(self.controlnet, ControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel)
    )
    is_multi_controlnet = isinstance(self.controlnet, MultiControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    )
    if not is_single_controlnet and not is_multi_controlnet:
        # Raised explicitly (instead of `assert False`) so the check survives `python -O`.
        raise TypeError(
            f"`controlnet` must be a `ControlNetModel` or `MultiControlNetModel` but is {type(self.controlnet)}"
        )

    # Check `image`
    if is_single_controlnet:
        self.check_image(image, prompt, prompt_embeds)
    else:
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")

        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        elif any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
        elif len(image) != len(self.controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)

    # Check `controlnet_conditioning_scale`
    if is_single_controlnet:
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif isinstance(controlnet_conditioning_scale, list):
        if any(isinstance(i, list) for i in controlnet_conditioning_scale):
            raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
        # Previously this length check sat in an unreachable `elif`; a list of the
        # wrong length must be rejected here rather than mis-paired later.
        if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
            raise ValueError(
                "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                " the same length as the number of controlnets"
            )

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    # Every guidance window must be non-empty and lie inside [0, 1].
    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
            )
def prepare_control_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    crops_coords,
    resize_mode,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """
    Preprocess a control image and expand it to the effective batch size.

    The image is run through `self.control_image_processor.preprocess` and cast to
    float32. A single image is repeated `batch_size` times; otherwise each image
    is repeated `num_images_per_prompt` times (element-wise, via
    `repeat_interleave`). The result is moved to `device`/`dtype` and, when
    classifier free guidance is on and `guess_mode` is off, concatenated with
    itself along dim 0.
    """
    processed = self.control_image_processor.preprocess(
        image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
    )
    processed = processed.to(dtype=torch.float32)

    # One image is shared by the whole batch; otherwise the image batch is taken
    # to match the prompt batch and is only expanded per prompt.
    repeat_by = batch_size if processed.shape[0] == 1 else num_images_per_prompt

    expanded = processed.repeat_interleave(repeat_by, dim=0)
    expanded = expanded.to(device=device, dtype=dtype)

    if not do_classifier_free_guidance or guess_mode:
        return expanded
    return torch.cat([expanded, expanded])
def prepare_mask_latents(
    self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
    """
    Resize the mask to latent resolution and produce latents for the masked image.

    The mask is interpolated to `(height // vae_scale_factor, width // vae_scale_factor)`.
    A `masked_image` with 4 channels is used as latents directly; anything else is
    encoded with `self._encode_vae_image`. Both tensors are repeated up to
    `batch_size` when they hold fewer items, and doubled along dim 0 for classifier
    free guidance.

    Returns:
        `(mask, masked_image_latents)`.

    Raises:
        ValueError: if `batch_size` is not a multiple of the number of masks or of
            the number of masked-image latents.
    """
    # Resize before the dtype cast to avoid breaking with cpu_offload and half precision.
    latent_size = (height // self.vae_scale_factor, width // self.vae_scale_factor)
    mask = torch.nn.functional.interpolate(mask, size=latent_size)
    mask = mask.to(device=device, dtype=dtype)

    masked_image = masked_image.to(device=device, dtype=dtype)

    if masked_image.shape[1] != 4:
        masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
    else:
        masked_image_latents = masked_image

    # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
    if mask.shape[0] < batch_size:
        if batch_size % mask.shape[0] != 0:
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                " of masks that you pass is divisible by the total requested batch size."
            )
        mask_repeats = batch_size // mask.shape[0]
        mask = mask.repeat(mask_repeats, 1, 1, 1)

    if masked_image_latents.shape[0] < batch_size:
        if batch_size % masked_image_latents.shape[0] != 0:
            raise ValueError(
                "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                " Make sure the number of images that you pass is divisible by the total requested batch size."
            )
        latent_repeats = batch_size // masked_image_latents.shape[0]
        masked_image_latents = masked_image_latents.repeat(latent_repeats, 1, 1, 1)

    if do_classifier_free_guidance:
        mask = torch.cat([mask, mask])
        masked_image_latents = torch.cat([masked_image_latents, masked_image_latents])

    # aligning device to prevent device errors when concating it with the latent model input
    masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
    return mask, masked_image_latents
analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @property def cross_attention_kwargs(self): return self._cross_attention_kwargs @property def num_timesteps(self): return self._num_timesteps @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, image: PipelineImageInput = None, mask_image: PipelineImageInput = None, control_image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, padding_mask_crop: Optional[int] = None, strength: float = 1.0, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 0.5, guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. 
If not defined, you need to pass `prompt_embeds`. image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, NumPy array or tensor representing an image batch to be used as the starting point. For both NumPy array and PyTorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a NumPy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image latents as `image`, but if passing latents directly it is not encoded again. mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, NumPy array or tensor representing an image batch to mask `image`. White pixels in the mask are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel (luminance) before use. If it's a NumPy array or PyTorch tensor, it should contain one color channel (L) instead of 3, so the expected shape for PyTorch tensor would be `(B, 1, H, W)`, `(B, H, W)`, `(1, H, W)`, `(H, W)`. And for NumPy array, it would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`, or `(H, W)`. control_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[List[torch.Tensor]]`, or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. 
If multiple ControlNets are specified in `init`, images must be passed as a list such that each element of the list can be correctly batched for input to a single ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. padding_mask_crop (`int`, *optional*, defaults to `None`): The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large and contain information irrelevant for inpainting, such as background. strength (`float`, *optional*, defaults to 1.0): Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a starting point and more noise is added the higher the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. 
Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. 
It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5): The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set the corresponding scale as a list. guess_mode (`bool`, *optional*, defaults to `False`): The ControlNet encoder tries to recognize the content of the input image even if you remove all prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the ControlNet stops applying. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. 
""" callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], ) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, control_image, mask_image, height, width, callback_steps, output_type, negative_prompt, prompt_embeds, negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, callback_on_step_end_tensor_inputs, padding_mask_crop, ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if padding_mask_crop is not None: height, width = self.image_processor.get_default_height_width(image, height, width) crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) resize_mode = "fill" else: crops_coords = None resize_mode = "default" device = self._execution_device if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions # 3. Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 4. 
Prepare image if isinstance(controlnet, ControlNetModel): control_image = self.prepare_control_image( image=control_image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, crops_coords=crops_coords, resize_mode=resize_mode, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) elif isinstance(controlnet, MultiControlNetModel): control_images = [] for control_image_ in control_image: control_image_ = self.prepare_control_image( image=control_image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, crops_coords=crops_coords, resize_mode=resize_mode, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) control_images.append(control_image_) control_image = control_images else: assert False # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width original_image = image init_image = self.image_processor.preprocess( image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode ) init_image = init_image.to(dtype=torch.float32) mask = self.mask_processor.preprocess( mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords ) masked_image = init_image * (mask < 0.5) _, _, height, width = init_image.shape # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps=num_inference_steps, strength=strength, device=device ) # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # create a boolean to check if the strength is set to 1. 
if so then initialise the latents with pure noise is_strength_max = strength == 1.0 self._num_timesteps = len(timesteps) # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels num_channels_unet = self.unet.config.in_channels return_image_latents = num_channels_unet == 4 latents_outputs = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, image=init_image, timestep=latent_timestep, is_strength_max=is_strength_max, return_noise=True, return_image_latents=return_image_latents, ) if return_image_latents: latents, noise, image_latents = latents_outputs else: latents, noise = latents_outputs # 7. Prepare mask latent variables mask, masked_image_latents = self.prepare_mask_latents( mask, masked_image, batch_size * num_images_per_prompt, height, width, prompt_embeds.dtype, device, generator, self.do_classifier_free_guidance, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None ) # 7.2 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=control_image, conditioning_scale=cond_scale, guess_mode=guess_mode, return_dict=False, ) if guess_mode and self.do_classifier_free_guidance: # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. 
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual if num_channels_unet == 9: latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if num_channels_unet == 4: init_latents_proper = image_latents if self.do_classifier_free_guidance: init_mask, _ = mask.chunk(2) else: init_mask = mask if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] init_latents_proper = self.scheduler.add_noise( init_latents_proper, noise, torch.tensor([noise_timestep]) ) latents = (1 - init_mask) * init_latents_proper + init_mask * latents if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % 
callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") torch.cuda.empty_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 0 ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if padding_mask_crop is not None: image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py ================================================ # Copyright 2024 Harutatsu Akiyama, Jinbin Bai, and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
# (restyled; no longer a verbatim copy)
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Pull latents out of an encoder output object.

    Args:
        encoder_output: Object exposing either a `latent_dist` attribute or a `latents` attribute.
        generator: Passed to `latent_dist.sample(...)` when `sample_mode` is `"sample"`.
        sample_mode: `"sample"` draws from `latent_dist`; `"argmax"` takes `latent_dist.mode()`.
            Any other value skips `latent_dist` and falls back to `encoder_output.latents`.

    Returns:
        Whatever `latent_dist.sample(generator)`, `latent_dist.mode()` or `encoder_output.latents` yields.

    Raises:
        AttributeError: If none of the lookups above applies to `encoder_output`.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()

    # Either there is no distribution, or the sample mode is not one we know: use precomputed latents.
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
# (restyled; no longer a verbatim copy)
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    Args:
        noise_cfg: Guided noise prediction to be rescaled.
        noise_pred_text: Text-conditioned noise prediction whose per-sample std is the rescaling target.
        guidance_rescale: Blend factor; `0.0` weights the rescaled term by zero, `1.0` returns only the
            rescaled prediction.

    Returns:
        `guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg`.
    """

    def _per_sample_std(tensor):
        # Reduce over every dimension except the first, keeping dims so the result broadcasts back.
        non_batch_dims = list(range(1, tensor.ndim))
        return tensor.std(dim=non_batch_dims, keepdim=True)

    # rescale the results from guidance (fixes overexposure)
    std_ratio = _per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg)
    rescaled = noise_cfg * std_ratio

    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion XL uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([`CLIPTextModelWithProjection`]): Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    requires_aesthetics_score: bool = False,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
    feature_extractor: Optional[CLIPImageProcessor] = None,
    image_encoder: Optional[CLIPVisionModelWithProjection] = None,
):
    """
    Register the pipeline components and build the image, mask and control-image processors.

    A `list`/`tuple` passed as `controlnet` is wrapped in a `MultiControlNetModel` before registration.
    `force_zeros_for_empty_prompt` and `requires_aesthetics_score` are stored in the pipeline config.
    When `add_watermarker` is `None`, the result of `is_invisible_watermark_available()` decides whether
    `self.watermark` is a `StableDiffusionXLWatermarker` or `None`.
    """
    super().__init__()

    # Normalise a plain sequence of ControlNets into a single multi-ControlNet module.
    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
    self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)

    # One factor of two per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    scale_factor = self.vae_scale_factor

    self.image_processor = VaeImageProcessor(vae_scale_factor=scale_factor)
    self.mask_processor = VaeImageProcessor(
        vae_scale_factor=scale_factor,
        do_normalize=False,
        do_binarize=True,
        do_convert_grayscale=True,
    )
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=scale_factor,
        do_convert_rgb=True,
        do_normalize=False,
    )

    # An explicit True/False from the caller wins; otherwise follow package availability.
    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()

    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None
If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
""" device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was 
truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = 
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode `image` with `self.image_encoder` and build a matching unconditional output.

    Args:
        image: A `torch.Tensor`, or any input accepted by `self.feature_extractor`
            (non-tensor inputs are converted to pixel values first).
        device: Device the pixel values are moved to before encoding.
        num_images_per_prompt: Number of times every batch entry is repeated
            (via `repeat_interleave` along dim 0).
        output_hidden_states: When truthy, return the penultimate hidden states of
            the encoder instead of its pooled `image_embeds`.

    Returns:
        A `(conditional, unconditional)` tuple. With hidden states, the unconditional
        part is the encoder output for an all-zero image; otherwise it is a zero
        tensor shaped like the conditional embeddings.
    """
    encoder = self.image_encoder
    # Follow the dtype of the encoder weights so the forward pass does not mix precisions.
    encoder_dtype = next(encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    pixels = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = encoder(pixels).image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    def _penultimate_states(batch):
        # Second-to-last hidden layer, repeated once per requested image.
        states = encoder(batch, output_hidden_states=True).hidden_states[-2]
        return states.repeat_interleave(num_images_per_prompt, dim=0)

    cond_states = _penultimate_states(pixels)
    uncond_states = _penultimate_states(torch.zeros_like(pixels))
    return cond_states, uncond_states
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Build the per-adapter image embeddings, tiled for `num_images_per_prompt`.

    Either `ip_adapter_image` is encoded here (one image per projection layer of
    `self.unet.encoder_hid_proj`), or pre-computed `ip_adapter_image_embeds` are used.
    With classifier-free guidance, pre-computed embeddings are split in two halves
    along dim 0 (negative first, positive second), and every returned tensor is the
    concatenation `[negative, positive]` along dim 0.

    Returns:
        A list with one tensor per adapter, moved to `device`.

    Raises:
        ValueError: If the number of images differs from the number of projection layers.
    """
    positives = []
    negatives = [] if do_classifier_free_guidance else None

    if ip_adapter_image_embeds is not None:
        for packed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_half, packed = packed.chunk(2)
                negatives.append(negative_half)
            positives.append(packed)
    else:
        images = ip_adapter_image if isinstance(ip_adapter_image, list) else [ip_adapter_image]
        projection_layers = self.unet.encoder_hid_proj.image_projection_layers
        if len(images) != len(projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(images)} images and {len(projection_layers)} IP Adapters."
            )

        for adapter_image, projection_layer in zip(images, projection_layers):
            # Only plain `ImageProjection` layers take pooled embeddings; others get hidden states.
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            positive, negative = self.encode_image(adapter_image, device, 1, wants_hidden_states)
            positives.append(positive[None, :])
            if do_classifier_free_guidance:
                negatives.append(negative[None, :])

    prepared = []
    for index, positive in enumerate(positives):
        tiled = torch.cat([positive] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negatives[index]] * num_images_per_prompt, dim=0)
            tiled = torch.cat([tiled_negative, tiled], dim=0)
        prepared.append(tiled.to(device=device))

    return prepared


def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Not all schedulers share the same `step` signature, so `eta` and `generator`
    are forwarded only when the signature declares a parameter of that name.
    eta (η) corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502)
    and should be between [0, 1].

    Returns:
        A dict containing `"eta"` and/or `"generator"`, in that order, or an empty dict.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    candidates = {"eta": eta, "generator": generator}
    return {name: value for name, value in candidates.items() if name in step_parameters}
def check_inputs(
    self,
    prompt,
    prompt_2,
    image,
    mask_image,
    strength,
    num_inference_steps,
    callback_steps,
    output_type,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
    callback_on_step_end_tensor_inputs=None,
    padding_mask_crop=None,
):
    """Validate the arguments of a pipeline call and raise on the first inconsistency.

    The checks run in this order: scalar ranges (`strength`, `num_inference_steps`,
    `callback_steps`), callback tensor names, prompt / embedding exclusivity and
    types, mask-crop requirements, pooled-embedding requirements, the control
    `image` (delegated to `self.check_image`), `controlnet_conditioning_scale`,
    the control-guidance windows and finally the IP-Adapter inputs.

    Returns:
        None. The method only raises.

    Raises:
        ValueError: For out-of-range values, mutually exclusive arguments that are
            both given, missing required arguments, or length mismatches.
        TypeError: For a control `image` / `controlnet_conditioning_scale` of the
            wrong type, or when `self.controlnet` is neither a `ControlNetModel`
            nor a `MultiControlNetModel` (optionally wrapped by `torch.compile`).
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
    if num_inference_steps is None:
        raise ValueError("`num_inference_steps` cannot be None.")
    elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
        raise ValueError(
            f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
            f" {type(num_inference_steps)}."
        )

    if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_2 is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if padding_mask_crop is not None:
        if not isinstance(image, PIL.Image.Image):
            raise ValueError(
                f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
            )
        if not isinstance(mask_image, PIL.Image.Image):
            raise ValueError(
                f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                f" {type(mask_image)}."
            )
        if output_type != "pil":
            raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")

    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )

    controlnet = self.controlnet

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # A `torch.compile`d controlnet is an `OptimizedModule` wrapping the real model in `_orig_mod`.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    is_single_controlnet = isinstance(controlnet, ControlNetModel) or (
        is_compiled and isinstance(controlnet._orig_mod, ControlNetModel)
    )
    is_multi_controlnet = isinstance(controlnet, MultiControlNetModel) or (
        is_compiled and isinstance(controlnet._orig_mod, MultiControlNetModel)
    )
    if not (is_single_controlnet or is_multi_controlnet):
        # Explicit raise instead of `assert False`, which is stripped when Python runs with -O.
        raise TypeError(
            "`controlnet` has to be a `ControlNetModel` or a `MultiControlNetModel` (optionally compiled with"
            f" `torch.compile`) but is {type(controlnet)}"
        )

    # Check `image`
    if is_single_controlnet:
        self.check_image(image, prompt, prompt_embeds)
    else:
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")
        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        if any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
        if len(image) != len(controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)

    # Check `controlnet_conditioning_scale`
    if is_single_controlnet:
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif isinstance(controlnet_conditioning_scale, list):
        if any(isinstance(i, list) for i in controlnet_conditioning_scale):
            raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
        # This check used to sit in an `elif` that repeated the `isinstance(..., list)` test
        # above and therefore never ran; a list of the wrong length slipped through.
        if len(controlnet_conditioning_scale) != len(controlnet.nets):
            raise ValueError(
                "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                " the same length as the number of controlnets"
            )

    if not isinstance(control_guidance_start, (tuple, list)):
        control_guidance_start = [control_guidance_start]

    if not isinstance(control_guidance_end, (tuple, list)):
        control_guidance_end = [control_guidance_end]

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(controlnet.nets)} controlnets available. Make sure to provide {len(controlnet.nets)}."
            )

    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
            )
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
    image=None,
    timestep=None,
    is_strength_max=True,
    add_noise=True,
    return_noise=False,
    return_image_latents=False,
):
    """Create the initial latents for denoising, optionally with the noise and image latents.

    Three cases, selected by `latents` and `add_noise`:

    * `latents is None` and `add_noise`: sample noise; start from pure noise scaled by
      the scheduler's `init_noise_sigma` when `is_strength_max`, otherwise from the
      image latents noised to `timestep`.
    * `latents` given and `add_noise`: treat `latents` as the noise and scale it by
      `init_noise_sigma`.
    * `add_noise` is false: start from the image latents themselves (noise is still
      sampled so it can be returned).

    Args:
        image: Pixel-space image, or a tensor whose dim 1 has size 4, which is used
            directly as latents. Required whenever image latents are needed.
        timestep: Noise level used with `scheduler.add_noise` when `is_strength_max` is false.

    Returns:
        A tuple `(latents,)`, extended with `noise` when `return_noise` and with
        `image_latents` when `return_image_latents`.

    Raises:
        ValueError: If a generator list does not match `batch_size`, or if `image` /
            `timestep` is missing in a mode that needs it.
    """
    shape = (
        batch_size,
        num_channels_latents,
        int(height) // self.vae_scale_factor,
        int(width) // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if (image is None or timestep is None) and not is_strength_max:
        raise ValueError(
            "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
            " However, either the image or the noise timestep has not been provided."
        )

    # Image latents are consumed by the `add_noise=False` branch as well; computing them
    # only for the other two conditions left `image_latents` unbound in that branch.
    needs_image_latents = return_image_latents or not add_noise or (latents is None and not is_strength_max)
    image_latents = None
    if needs_image_latents:
        if image is None:
            raise ValueError(
                "`image` has to be provided when `return_image_latents` is True or `add_noise` is False."
            )
        image = image.to(device=device, dtype=dtype)

        if image.shape[1] == 4:
            # Already in latent space: skip the VAE.
            image_latents = image
        else:
            image_latents = self._encode_vae_image(image=image, generator=generator)
        image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)

    if latents is None and add_noise:
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if is_strength_max:
            # if strength is 1. the latents are pure noise, scaled by the scheduler's init sigma
            latents = noise * self.scheduler.init_noise_sigma
        else:
            # otherwise start from image + noise
            latents = self.scheduler.add_noise(image_latents, noise, timestep)
    elif add_noise:
        noise = latents.to(device)
        latents = noise * self.scheduler.init_noise_sigma
    else:
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = image_latents.to(device)

    outputs = (latents,)

    if return_noise:
        outputs += (noise,)

    if return_image_latents:
        outputs += (image_latents,)

    return outputs


def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """Encode `image` with `self.vae` and return latents multiplied by the VAE scaling factor.

    When `self.vae.config.force_upcast` is set, both the image and the VAE are
    switched to float32 for the encode and the VAE is moved back to the input dtype
    afterwards, even if encoding raises. With a list of generators the batch is
    encoded one sample at a time, each sample using its own generator.

    Returns:
        Latents in the dtype `image` had on entry.
    """
    input_dtype = image.dtype
    force_upcast = self.vae.config.force_upcast
    if force_upcast:
        image = image.float()
        self.vae.to(dtype=torch.float32)

    try:
        if isinstance(generator, list):
            per_sample = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(per_sample, dim=0)
        else:
            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
    finally:
        # Do not leave the VAE in float32 when encoding fails.
        if force_upcast:
            self.vae.to(input_dtype)

    image_latents = image_latents.to(input_dtype)
    return self.vae.config.scaling_factor * image_latents
passed mask and the required batch size don't match. Masks are supposed to be duplicated to" f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" " of masks that you pass is divisible by the total requested batch size." ) mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask masked_image_latents = None if masked_image is not None: masked_image = masked_image.to(device=device, dtype=dtype) masked_image_latents = self._encode_vae_image(masked_image, generator=generator) if masked_image_latents.shape[0] < batch_size: if not batch_size % masked_image_latents.shape[0] == 0: raise ValueError( "The passed images and the required batch size don't match. Images are supposed to be duplicated" f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." " Make sure the number of images that you pass is divisible by the total requested batch size." 
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
    """Select the part of the scheduler's timestep schedule that will actually be run.

    Without `denoising_start`, the first `num_inference_steps - int(num_inference_steps * strength)`
    steps are skipped. With `denoising_start`, `strength` is irrelevant: only the
    timesteps below the cutoff `num_train_timesteps * (1 - denoising_start)` are kept.

    Args:
        num_inference_steps: Number of steps the scheduler was configured with.
        strength: Fraction of the schedule to run when `denoising_start` is None.
        device: Unused here; kept for signature compatibility.
        denoising_start: Optional fraction of the denoising process to skip.

    Returns:
        `(timesteps, num_steps)`: the retained timesteps and the number of steps to run.
    """
    scheduler = self.scheduler

    if denoising_start is None:
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = scheduler.timesteps[t_start * scheduler.order :]
        return timesteps, num_inference_steps - t_start

    timesteps = scheduler.timesteps
    num_train_timesteps = scheduler.config.num_train_timesteps
    discrete_timestep_cutoff = int(round(num_train_timesteps - (denoising_start * num_train_timesteps)))

    num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
    if scheduler.order == 2 and num_inference_steps % 2 == 0:
        # if the scheduler is a 2nd order scheduler we might have to do +1
        # because `num_inference_steps` might be even given that every timestep
        # (except the highest one) is duplicated. If `num_inference_steps` is even it would
        # mean that we cut the timesteps in the middle of the denoising step
        # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
        # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
        num_inference_steps = num_inference_steps + 1

    # because t_n+1 >= t_n, we slice the timesteps starting from the end.
    # An explicit start index is used because `timesteps[-0:]` is the WHOLE schedule:
    # when no timestep lies below the cutoff the result must be empty, not everything.
    first_kept = max(len(timesteps) - num_inference_steps, 0)
    timesteps = timesteps[first_kept:]
    return timesteps, num_inference_steps
def upcast_vae(self):
    """Move `self.vae` to float32, keeping selected sub-modules in the original dtype.

    If the first attention block of the decoder's mid block uses `AttnProcessor2_0`
    or `XFormersAttnProcessor`, the attention block does not need to be in float32,
    which can save lots of memory: `post_quant_conv`, `decoder.conv_in` and
    `decoder.mid_block` are then moved back to the dtype the VAE had on entry.
    """
    vae = self.vae
    original_dtype = vae.dtype
    vae.to(dtype=torch.float32)

    attn_processor = vae.decoder.mid_block.attentions[0].processor
    if not isinstance(attn_processor, (AttnProcessor2_0, XFormersAttnProcessor)):
        return

    for module in (vae.post_quant_conv, vae.decoder.conv_in, vae.decoder.mid_block):
        module.to(original_dtype)


@property
def guidance_scale(self):
    """The value stored in `self._guidance_scale`."""
    return self._guidance_scale


@property
def clip_skip(self):
    """The value stored in `self._clip_skip`."""
    return self._clip_skip


# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """Whether `self._guidance_scale` is strictly greater than 1."""
    return self._guidance_scale > 1


@property
def cross_attention_kwargs(self):
    """The value stored in `self._cross_attention_kwargs`."""
    return self._cross_attention_kwargs


@property
def num_timesteps(self):
    """The value stored in `self._num_timesteps`."""
    return self._num_timesteps
Optional[int] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders image (`PIL.Image.Image`): `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will be masked out with `mask_image` and repainted according to `prompt`. mask_image (`PIL.Image.Image`): `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. padding_mask_crop (`int`, *optional*, defaults to `None`): The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and contains all masked area, and then expand that area based on `padding_mask_crop`. 
The image and mask_image will then be cropped based on the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large and contain information irrelevant for inpainting, such as background. strength (`float`, *optional*, defaults to 0.9999): Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked portion of the reference `image`. Note that in the case of `denoising_start` being declared as an integer, the value of `strength` will be ignored. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. denoising_start (`float`, *optional*): When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). 
denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(width, height)`. 
Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). aesthetic_score (`float`, *optional*, defaults to 6.0): Used to simulate an aesthetic score of the generated image by influencing the positive text condition. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_aesthetic_score (`float`, *optional*, defaults to 2.5): Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to simulate an aesthetic score of the generated image by influencing the negative text condition. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. 
Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. `tuple. When returning a tuple, the first element is a list with the generated images. """ callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], ) # # 0.0 Default height and width to unet # height = height or self.unet.config.sample_size * self.vae_scale_factor # width = width or self.unet.config.sample_size * self.vae_scale_factor # 0.1 align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): 
control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], ) # 1. Check inputs self.check_inputs( prompt, prompt_2, control_image, mask_image, strength, num_inference_steps, callback_steps, output_type, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, callback_on_step_end_tensor_inputs, padding_mask_crop, ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) # 3. 
Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) # 3.1 Encode ip_adapter_image if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 4. set timesteps def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, device, denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, ) # check that number of inference steps is not < 1 - as this doesn't make sense if num_inference_steps < 1: raise ValueError( f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." ) # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # create a boolean to check if the strength is set to 1. 
if so then initialise the latents with pure noise is_strength_max = strength == 1.0 self._num_timesteps = len(timesteps) # 5. Preprocess mask and image - resizes image and mask w.r.t height and width # 5.1 Prepare init image if padding_mask_crop is not None: height, width = self.image_processor.get_default_height_width(image, height, width) crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) resize_mode = "fill" else: crops_coords = None resize_mode = "default" original_image = image init_image = self.image_processor.preprocess( image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode ) init_image = init_image.to(dtype=torch.float32) # 5.2 Prepare control images if isinstance(controlnet, ControlNetModel): control_image = self.prepare_control_image( image=control_image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, crops_coords=crops_coords, resize_mode=resize_mode, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) elif isinstance(controlnet, MultiControlNetModel): control_images = [] for control_image_ in control_image: control_image_ = self.prepare_control_image( image=control_image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, crops_coords=crops_coords, resize_mode=resize_mode, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) control_images.append(control_image_) control_image = control_images else: raise ValueError(f"{controlnet.__class__} is not supported.") # 5.3 Prepare mask mask = self.mask_processor.preprocess( mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords ) masked_image = init_image * (mask < 0.5) _, _, height, width = init_image.shape # 
6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels num_channels_unet = self.unet.config.in_channels return_image_latents = num_channels_unet == 4 add_noise = True if denoising_start is None else False latents_outputs = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, image=init_image, timestep=latent_timestep, is_strength_max=is_strength_max, add_noise=add_noise, return_noise=True, return_image_latents=return_image_latents, ) if return_image_latents: latents, noise, image_latents = latents_outputs else: latents, noise = latents_outputs # 7. Prepare mask latent variables mask, masked_image_latents = self.prepare_mask_latents( mask, masked_image, batch_size * num_images_per_prompt, height, width, prompt_embeds.dtype, device, generator, self.do_classifier_free_guidance, ) # 8. Check that sizes of mask, masked image and latents match if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) elif num_channels_unet != 4: raise ValueError( f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." ) # 8.1 Prepare extra step kwargs. 
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8.2 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps if isinstance(controlnet, MultiControlNetModel) else keeps[0]) # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline height, width = latents.shape[-2:] height = height * self.vae_scale_factor width = width * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 10. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device) # 11. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) if ( denoising_end is not None and denoising_start is not None and denoising_value_valid(denoising_end) and denoising_value_valid(denoising_start) and denoising_start >= denoising_end ): raise ValueError( f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: " + f" {denoising_end} when using type float." ) elif denoising_end is not None and denoising_value_valid(denoising_end): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # controlnet(s) inference if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. 
control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] controlnet_added_cond_kwargs = { "text_embeds": add_text_embeds.chunk(2)[1], "time_ids": add_time_ids.chunk(2)[1], } else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] # # Resize control_image to match the size of the input to the controlnet # if control_image.shape[-2:] != control_model_input.shape[-2:]: # control_image = F.interpolate(control_image, size=control_model_input.shape[-2:], mode="bilinear", align_corners=False) down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=control_image, conditioning_scale=cond_scale, guess_mode=guess_mode, added_cond_kwargs=controlnet_added_cond_kwargs, return_dict=False, ) if guess_mode and self.do_classifier_free_guidance: # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. 
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds if num_channels_unet == 9: latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if num_channels_unet == 4: init_latents_proper = image_latents if self.do_classifier_free_guidance: init_mask, _ = mask.chunk(2) else: init_mask = mask if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] init_latents_proper = self.scheduler.add_noise( init_latents_proper, noise, torch.tensor([noise_timestep]) ) latents = (1 - init_mask) * init_latents_proper + init_mask * latents if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") torch.cuda.empty_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] 
else: return StableDiffusionXLPipelineOutput(images=latents) # apply watermark if available if self.watermark is not None: image = self.watermark.apply_watermark(image) image = self.image_processor.postprocess(image, output_type=output_type) if padding_mask_crop is not None: image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from diffusers.utils.import_utils import is_invisible_watermark_available from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput if is_invisible_watermark_available(): from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker from .multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install opencv-python transformers accelerate >>> from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> import cv2 >>> from PIL import Image >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" >>> negative_prompt = "low quality, bad quality, sketches" 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps and custom sigmas. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from. Only its `set_timesteps` method and `timesteps` attribute are used.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. Only forwarded to
            the scheduler when neither `timesteps` nor `sigmas` is passed; otherwise it is ignored and recomputed
            from the resulting schedule.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `scheduler.set_timesteps` as the `device` keyword argument.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. Cannot be combined
            with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. Cannot be combined with
            `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule read back from
        `scheduler.timesteps` and the second element is the number of inference steps (the length of that schedule
        when a custom schedule was supplied, otherwise `num_inference_steps` unchanged).

    Raises:
        `ValueError`: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` signature
        has no parameter for the custom schedule that was requested.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Pick the custom schedule (if any). `timesteps` and `sigmas` are handled identically apart from the keyword
    # they are forwarded under and the wording of the error message, so both share the code path below.
    if timesteps is not None:
        schedule_kwarg, schedule_values, schedule_label = "timesteps", timesteps, "timestep"
    elif sigmas is not None:
        schedule_kwarg, schedule_values, schedule_label = "sigmas", sigmas, "sigmas"
    else:
        # Default path: let the scheduler derive its own spacing from the requested step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # `Signature.parameters` is a mapping keyed by parameter name, so membership can be tested on it directly.
    if schedule_kwarg not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {schedule_label} schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(device=device, **{schedule_kwarg: schedule_values}, **kwargs)
    # The scheduler is the source of truth for the final schedule; the step count follows from its length.
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
    feature_extractor: CLIPImageProcessor = None,
    image_encoder: CLIPVisionModelWithProjection = None,
):
    """
    Register the pipeline components and build the helpers derived from them.

    A `list` or `tuple` passed as `controlnet` is wrapped in a `MultiControlNetModel` before registration. After
    the modules are registered, `vae_scale_factor` is computed from the number of entries in
    `vae.config.block_out_channels`, two `VaeImageProcessor` instances are created (one for images, one for
    control images with `do_normalize=False`), the watermarker is set up, and `force_zeros_for_empty_prompt` is
    stored in the pipeline config.

    When `add_watermarker` is `None`, it falls back to the result of `is_invisible_watermark_available()`.
    """
    super().__init__()

    # Several ControlNets given as a plain sequence are combined into a single MultiControlNetModel.
    controlnet_module = MultiControlNetModel(controlnet) if isinstance(controlnet, (list, tuple)) else controlnet

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        controlnet=controlnet_module,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )

    # The scale factor is a power of two driven by the number of VAE blocks; it is read from the registered VAE.
    num_vae_blocks = len(self.vae.config.block_out_channels)
    scale_factor = 2 ** (num_vae_blocks - 1)
    self.vae_scale_factor = scale_factor
    self.image_processor = VaeImageProcessor(vae_scale_factor=scale_factor, do_convert_rgb=True)
    # Control images use a separate processor with normalization switched off.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=scale_factor, do_convert_rgb=True, do_normalize=False
    )

    # An unspecified `add_watermarker` defers to whether the watermark package can be used.
    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()
    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None

    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
If not defined, `prompt` is used in both text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = 
tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode `image` with `self.image_encoder` and return a `(conditional, unconditional)` pair.

    Non-tensor input is first converted with `self.feature_extractor`. The image is moved to `device` and cast to
    the dtype of the encoder's first parameter. Each result is repeated `num_images_per_prompt` times along dim 0.

    With `output_hidden_states` truthy, both outputs are the encoder's second-to-last hidden state, the
    unconditional one computed from an all-zero image. Otherwise the conditional output is `image_embeds` and the
    unconditional output is a zero tensor of the same shape.
    """
    encoder = self.image_encoder
    encoder_dtype = next(encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        pooled = encoder(image).image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return pooled, torch.zeros_like(pooled)

    def _penultimate_states(pixels):
        # Second-to-last hidden state, duplicated once per requested image.
        states = encoder(pixels, output_hidden_states=True).hidden_states[-2]
        return states.repeat_interleave(num_images_per_prompt, dim=0)

    return _penultimate_states(image), _penultimate_states(torch.zeros_like(image))
def check_inputs(
    self,
    prompt,
    prompt_2,
    image,
    callback_steps,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    negative_pooled_prompt_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the arguments of a pipeline call and raise on the first inconsistency found.

    Checks, in order: `callback_steps`, `callback_on_step_end_tensor_inputs`, the prompt / prompt-embedding
    combinations, the ControlNet conditioning `image`, `controlnet_conditioning_scale`, the
    `control_guidance_start` / `control_guidance_end` windows, and the IP-Adapter inputs.

    Raises:
        ValueError: for conflicting, missing or out-of-range arguments, including a
            `controlnet_conditioning_scale` list whose length differs from the number of ControlNets.
        TypeError: for arguments of the wrong type, and when `self.controlnet` is neither a `ControlNetModel`
            nor a `MultiControlNetModel` (optionally wrapped by `torch.compile`).
    """
    if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_2 is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # A `torch.compile`d ControlNet hides the real module behind `_orig_mod`, so classify the
    # controlnet once and reuse the result for both the image and the scale checks.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    is_single_controlnet = (
        isinstance(self.controlnet, ControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, ControlNetModel)
    )
    is_multi_controlnet = not is_single_controlnet and (
        isinstance(self.controlnet, MultiControlNetModel)
        or is_compiled
        and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    )
    if not (is_single_controlnet or is_multi_controlnet):
        # Explicit raise instead of `assert False`: assertions are removed under `python -O`.
        raise TypeError(
            "`controlnet` must be a `ControlNetModel` or `MultiControlNetModel` (optionally compiled), but is"
            f" {type(self.controlnet)}."
        )

    # Check `image`
    if is_single_controlnet:
        self.check_image(image, prompt, prompt_embeds)
    else:
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")

        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        elif any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        elif len(image) != len(self.controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)

    # Check `controlnet_conditioning_scale`
    if is_single_controlnet:
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif isinstance(controlnet_conditioning_scale, list):
        if any(isinstance(i, list) for i in controlnet_conditioning_scale):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        # The length check used to sit in an `elif` that repeated the `isinstance(..., list)` test of
        # the branch above, so it was never applied to a flat list; run it for every list instead.
        if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
            raise ValueError(
                "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                " the same length as the number of controlnets"
            )

    if not isinstance(control_guidance_start, (tuple, list)):
        control_guidance_start = [control_guidance_start]

    if not isinstance(control_guidance_end, (tuple, list)):
        control_guidance_end = [control_guidance_end]

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
            )
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
def prepare_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """
    Preprocess the ControlNet conditioning image and expand it along the batch dimension.

    The image goes through `self.control_image_processor.preprocess` at the requested `height` / `width` and is
    cast to float32. A preprocessed batch of one is repeated `batch_size` times; any other batch is repeated
    `num_images_per_prompt` times per entry. The result is moved to `device` / `dtype`, and concatenated with
    itself when `do_classifier_free_guidance` is set and `guess_mode` is not.
    """
    control = self.control_image_processor.preprocess(image, height=height, width=width)
    control = control.to(dtype=torch.float32)

    # A single control image serves the whole batch; otherwise there is one image per prompt.
    copies = batch_size if control.shape[0] == 1 else num_images_per_prompt
    control = control.repeat_interleave(copies, dim=0).to(device=device, dtype=dtype)

    needs_cfg_copy = do_classifier_free_guidance and not guess_mode
    return torch.cat((control, control)) if needs_cfg_copy else control
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
def _get_add_time_ids(
    self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
):
    """
    Build the additional time-id tensor from the three size/crop tuples.

    Args:
        original_size, crops_coords_top_left, target_size: sequences concatenated with `+`, in that order.
        dtype: dtype of the returned tensor.
        text_encoder_projection_dim: projection width added to the time-embedding width when checking the UNet
            config. When `None`, `self.text_encoder_2.config.projection_dim` is used.

    Returns:
        A tensor of shape `(1, len(ids))` holding the concatenated values.

    Raises:
        ValueError: if no projection dimension can be determined, or if the resulting embedding width differs
            from `self.unet.add_embedding.linear_1.in_features`.
    """
    if text_encoder_projection_dim is None:
        # The default used to fall through to `int + None` and crash with an opaque TypeError.
        second_text_encoder = getattr(self, "text_encoder_2", None)
        if second_text_encoder is None:
            raise ValueError(
                "`text_encoder_projection_dim` must be provided when the pipeline has no `text_encoder_2`."
            )
        text_encoder_projection_dim = second_text_encoder.config.projection_dim

    add_time_ids = list(original_size + crops_coords_top_left + target_size)

    passed_add_embed_dim = (
        self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
    )
    expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

    if expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    return torch.tensor([add_time_ids], dtype=dtype)
""" assert len(w.shape) == 1 w = w * 1000.0 half_dim = embedding_dim // 2 emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) emb = w.to(dtype)[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb @property def guidance_scale(self): return self._guidance_scale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None @property def cross_attention_kwargs(self): return self._cross_attention_kwargs @property def denoising_end(self): return self._denoising_end @property def num_timesteps(self): return self._num_timesteps @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: 
Optional[torch.Tensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders. image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. 
If multiple ControlNets are specified in `init`, images must be passed as a list such that each element of the list can be correctly batched for input to a single ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. 
The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. 
Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, pooled text embeddings are generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set the corresponding scale as a list. guess_mode (`bool`, *optional*, defaults to `False`): The ControlNet encoder tries to recognize the content of the input image even if you remove all prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the ControlNet stops applying. original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(height, width)`. 
Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a specific image resolution. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a target image resolution. It should be as same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. 
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned containing the output images. """ callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = 
len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], ) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, prompt_2, image, callback_steps, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, negative_pooled_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions # 3.1 Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt, prompt_2, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt, negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 
lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) # 3.2 Encode ip_adapter_image if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 4. Prepare image if isinstance(controlnet, ControlNetModel): image = self.prepare_image( image=image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) height, width = image.shape[-2:] elif isinstance(controlnet, MultiControlNetModel): images = [] for image_ in image: image_ = self.prepare_image( image=image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=guess_mode, ) images.append(image_) image = images height, width = image[0].shape[-2:] else: assert False # 5. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) self._num_timesteps = len(timesteps) # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6.5 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) # 7. 
Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # 7.2 Prepare added time ids & embeddings if isinstance(image, list): original_size = original_size or image[0].shape[-2:] else: original_size = original_size or image.shape[-2:] target_size = target_size or (height, width) add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order # 8.1 Apply denoising_end if ( self.denoising_end is not None and isinstance(self.denoising_end, float) and self.denoising_end > 0 and self.denoising_end < 1 ): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} # controlnet(s) inference if guess_mode and self.do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. 
control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] controlnet_added_cond_kwargs = { "text_embeds": add_text_embeds.chunk(2)[1], "time_ids": add_time_ids.chunk(2)[1], } else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds controlnet_added_cond_kwargs = added_cond_kwargs if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=image, conditioning_scale=cond_scale, guess_mode=guess_mode, added_cond_kwargs=controlnet_added_cond_kwargs, return_dict=False, ) if guess_mode and self.do_classifier_free_guidance: # Inferred ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. 
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps 
== 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None if has_latents_mean and has_latents_std: latents_mean = ( torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents_std = ( torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean else: latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents if not output_type == "latent": # apply watermark if available if self.watermark is not None: image = self.watermark.apply_watermark(image) image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from diffusers.utils.import_utils import is_invisible_watermark_available from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput if is_invisible_watermark_available(): from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker from .multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # pip install accelerate transformers 
safetensors diffusers >>> import torch >>> import numpy as np >>> from PIL import Image >>> from transformers import DPTImageProcessor, DPTForDepthEstimation >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL >>> from diffusers.utils import load_image >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") >>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") >>> controlnet = ControlNetModel.from_pretrained( ... "diffusers/controlnet-depth-sdxl-1.0-small", ... variant="fp16", ... use_safetensors=True, ... torch_dtype=torch.float16, ... ) >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) >>> pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained( ... "stabilityai/stable-diffusion-xl-base-1.0", ... controlnet=controlnet, ... vae=vae, ... variant="fp16", ... use_safetensors=True, ... torch_dtype=torch.float16, ... ) >>> pipe.enable_model_cpu_offload() >>> def get_depth_map(image): ... image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda") ... with torch.no_grad(), torch.autocast("cuda"): ... depth_map = depth_estimator(image).predicted_depth ... depth_map = torch.nn.functional.interpolate( ... depth_map.unsqueeze(1), ... size=(1024, 1024), ... mode="bicubic", ... align_corners=False, ... ) ... depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) ... depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) ... depth_map = (depth_map - depth_min) / (depth_max - depth_min) ... image = torch.cat([depth_map] * 3, dim=1) ... image = image.permute(0, 2, 3, 1).cpu().numpy()[0] ... image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8)) ... return image >>> prompt = "A robot, 4k photo" >>> image = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... 
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Pull latents out of an encoder output object.

    Lookup order:
        1. If `encoder_output` has a `latent_dist` attribute and `sample_mode` is `"sample"`, the result of
           `latent_dist.sample(generator)` is returned.
        2. If `encoder_output` has a `latent_dist` attribute and `sample_mode` is `"argmax"`, the result of
           `latent_dist.mode()` is returned.
        3. Otherwise, if `encoder_output` has a `latents` attribute, that attribute is returned as is.

    Args:
        encoder_output: Object exposing either a `latent_dist` or a `latents` attribute.
        generator: Forwarded to `latent_dist.sample`; only used in `"sample"` mode.
        sample_mode: `"sample"` or `"argmax"`. Any other value skips `latent_dist` entirely.

    Raises:
        AttributeError: If none of the lookups above applies.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()

    # Reached when there is no distribution, or when `sample_mode` is neither "sample" nor "argmax".
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([` CLIPTextModelWithProjection`]): Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. 
scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config of `stabilityai/stable-diffusion-xl-refiner-1-0`. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1-0`. add_watermarker (`bool`, *optional*): Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to watermark output images. If not defined, it will default to True if the package is installed, otherwise no watermarker will be used. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    requires_aesthetics_score: bool = False,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
    feature_extractor: CLIPImageProcessor = None,
    image_encoder: CLIPVisionModelWithProjection = None,
):
    """
    Register the pipeline components and build the image processors.

    A `list` or `tuple` passed as `controlnet` is wrapped in a `MultiControlNetModel` before registration.
    When `add_watermarker` is `None`, the watermarker is enabled only if
    `is_invisible_watermark_available()` returns a truthy value. `force_zeros_for_empty_prompt` and
    `requires_aesthetics_score` are stored in the pipeline config.
    """
    super().__init__()

    # Normalise a plain sequence of ControlNets into a single wrapper module.
    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    modules = {
        "vae": vae,
        "text_encoder": text_encoder,
        "text_encoder_2": text_encoder_2,
        "tokenizer": tokenizer,
        "tokenizer_2": tokenizer_2,
        "unet": unet,
        "controlnet": controlnet,
        "scheduler": scheduler,
        "feature_extractor": feature_extractor,
        "image_encoder": image_encoder,
    }
    self.register_modules(**modules)

    # The scale factor doubles for every VAE block after the first.
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)

    # Both processors share these settings; the control-image one additionally skips normalisation.
    shared_processor_kwargs = {"vae_scale_factor": self.vae_scale_factor, "do_convert_rgb": True}
    self.image_processor = VaeImageProcessor(**shared_processor_kwargs)
    self.control_image_processor = VaeImageProcessor(**shared_processor_kwargs, do_normalize=False)

    # An explicit True/False wins; only `None` defers to package availability.
    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()
    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None

    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
    self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
    self,
    prompt: str,
    prompt_2: Optional[str] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[str] = None,
    negative_prompt_2: Optional[str] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        device: (`torch.device`):
            torch device. Falls back to `self._execution_device` when not given.
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        A 4-tuple `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)`.
        The two negative entries are returned exactly as passed in (possibly `None`) when
        `do_classifier_free_guidance` is `False`.

    Raises:
        TypeError: if `negative_prompt` and `prompt` end up with different types after normalization.
        ValueError: if the batch size of `negative_prompt` does not match the prompt batch size.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if self.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(self.text_encoder_2, lora_scale)

    # A single string is treated as a batch of one.
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt is not None:
        batch_size = len(prompt)
    else:
        # No text prompt: the batch size comes from the pre-computed embeddings.
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders
    # Each list holds a single entry when the first tokenizer / text encoder is missing.
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # textual inversion: process multi-vector tokens if necessary
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        # NOTE(review): this loop rebinds the `prompt` argument; after it, `prompt` refers to the
        # last element that was iterated, which is what the type check and error messages below see.
        # `zip` stops at the shortest input, so with single-entry `tokenizers`/`text_encoders`
        # only `prompts[0]` is encoded.
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, tokenizer)

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            # Tokenize a second time without truncation, only to detect and report dropped text.
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

            # We are only ALWAYS interested in the pooled output of the final text encoder
            # (the assignment is overwritten on every iteration, so the last encoder wins).
            pooled_prompt_embeds = prompt_embeds[0]
            if clip_skip is None:
                prompt_embeds = prompt_embeds.hidden_states[-2]
            else:
                # "2" because SDXL always indexes from the penultimate layer.
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

            prompt_embeds_list.append(prompt_embeds)

        # Join the per-encoder hidden states along the feature axis.
        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # normalize str to list
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt_2 = (
            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
        )

        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = [negative_prompt, negative_prompt_2]

        negative_prompt_embeds_list = []
        # NOTE(review): as above, this loop rebinds the `negative_prompt` argument.
        for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

            # Pad/truncate the negative prompt to the sequence length of the positive embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            # NOTE(review): the negative branch always takes `hidden_states[-2]`; `clip_skip` is not
            # applied here, unlike in the positive branch above. Confirm whether that is intended.
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    # Cast to the dtype of the second text encoder when present, otherwise to the UNet dtype.
    if self.text_encoder_2 is not None:
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    else:
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        if self.text_encoder_2 is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    # The pooled embeddings are 2-D, hence the two-argument repeat/view.
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode an image with `self.image_encoder` and build the matching unconditional encoding.

    Args:
        image: a `torch.Tensor`, or any other input, which is first run through `self.feature_extractor`.
        device: device the image tensor is moved to before encoding.
        num_images_per_prompt (`int`): repeat count applied along dim 0 with `repeat_interleave`.
        output_hidden_states (`bool`, *optional*): when truthy, return the penultimate hidden states
            instead of `image_embeds`.

    Returns:
        A `(conditional, unconditional)` pair. With `output_hidden_states` the unconditional half is the
        encoding of an all-zeros image; otherwise it is an all-zeros tensor shaped like the image embeds.
    """
    # Use the dtype of the encoder's parameters for the input tensor.
    dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=dtype)
    if output_hidden_states:
        image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
        image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        # The unconditional branch runs the encoder a second time on a zero image.
        uncond_image_enc_hidden_states = self.image_encoder(
            torch.zeros_like(image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
            num_images_per_prompt, dim=0
        )
        return image_enc_hidden_states, uncond_image_enc_hidden_states
    else:
        image_embeds = self.image_encoder(image).image_embeds
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        uncond_image_embeds = torch.zeros_like(image_embeds)

        return image_embeds, uncond_image_embeds


# Copied from
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build one image-embedding tensor per IP-Adapter projection layer.

    Args:
        ip_adapter_image: a single image or a list of images; only used when
            `ip_adapter_image_embeds` is `None`.
        ip_adapter_image_embeds: optional list of pre-computed embeddings. With classifier-free
            guidance each tensor is split in two with `chunk(2)`: first half negative, second half
            positive.
        device: device the resulting tensors are moved to.
        num_images_per_prompt (`int`): number of times each embedding is tiled along dim 0.
        do_classifier_free_guidance (`bool`): when true, negative embeddings are concatenated in
            front of the positive ones.

    Returns:
        A list of tensors, one per entry of the input.

    Raises:
        ValueError: if the number of images differs from the number of
            `self.unet.encoder_hid_proj.image_projection_layers`.
    """
    image_embeds = []
    if do_classifier_free_guidance:
        negative_image_embeds = []
    if ip_adapter_image_embeds is None:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            # Hidden states are requested for every projection layer that is not a plain `ImageProjection`.
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )

            image_embeds.append(single_image_embeds[None, :])
            if do_classifier_free_guidance:
                negative_image_embeds.append(single_negative_image_embeds[None, :])
    else:
        for single_image_embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                negative_image_embeds.append(single_negative_image_embeds)
            image_embeds.append(single_image_embeds)

    ip_adapter_image_embeds = []
    for i, single_image_embeds in enumerate(image_embeds):
        single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
            single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

        single_image_embeds = single_image_embeds.to(device=device)
        ip_adapter_image_embeds.append(single_image_embeds)

    return ip_adapter_image_embeds


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments accepted by `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are forwarded only
    when the signature declares a parameter of that name.

    Args:
        generator: random generator(s) to forward under the `generator` key.
        eta: value to forward under the `eta` key. eta (η) is only used with the DDIMScheduler; it
            corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be
            between [0, 1].

    Returns:
        `dict` containing at most the keys `"eta"` and `"generator"`.
    """
    # Inspect the signature once and reuse it for both lookups.
    step_parameters = inspect.signature(self.scheduler.step).parameters

    extra_step_kwargs = {}
    if "eta" in step_parameters:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_parameters:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs


def check_inputs(
    self,
    prompt,
    prompt_2,
    image,
    strength,
    num_inference_steps,
    callback_steps,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the arguments of a pipeline call and raise on the first inconsistency found.

    Args:
        prompt, prompt_2: text prompts (`str` or `list`); mutually exclusive with `prompt_embeds`.
        image: conditioning image(s) handed to `check_image`; must be a flat list with one entry per
            ControlNet when `self.controlnet` is a `MultiControlNetModel`.
        strength (`float`): must lie in `[0, 1]`.
        num_inference_steps (`int`): must be a positive integer.
        callback_steps (`int`, *optional*): must be a positive integer when given.
        negative_prompt, negative_prompt_2: mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds, negative_prompt_embeds: must have equal shapes when both are given, and each
            requires its pooled counterpart.
        pooled_prompt_embeds, negative_pooled_prompt_embeds: pooled counterparts of the above.
        ip_adapter_image, ip_adapter_image_embeds: mutually exclusive; the embeds must be a list
            whose first element is a 3D or 4D tensor.
        controlnet_conditioning_scale: `float` for a single ControlNet; for multiple ControlNets a
            list must be flat and have one entry per ControlNet.
        control_guidance_start, control_guidance_end: scalars or equally long lists/tuples with
            `0 <= start < end <= 1` element-wise.
        callback_on_step_end_tensor_inputs (`list`, *optional*): every entry must be listed in
            `self._callback_tensor_inputs`.

    Raises:
        ValueError: for out-of-range values and inconsistent argument combinations.
        TypeError: for an `image` or `controlnet_conditioning_scale` of the wrong type.
        AssertionError: if `self.controlnet` is neither a (compiled) `ControlNetModel` nor a
            (compiled) `MultiControlNetModel`.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
    if num_inference_steps is None:
        raise ValueError("`num_inference_steps` cannot be None.")
    elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
        raise ValueError(
            f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
            f" {type(num_inference_steps)}."
        )

    if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_2 is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # A `torch.compile`d ControlNet is wrapped in an `OptimizedModule`; look through the wrapper
    # (`_orig_mod`) to classify it. The classification is shared by the image and scale checks.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    is_single_controlnet = isinstance(self.controlnet, ControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel)
    )
    is_multi_controlnet = isinstance(self.controlnet, MultiControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    )
    if not (is_single_controlnet or is_multi_controlnet):
        # Explicit raise instead of `assert False`, which would be stripped under `python -O`.
        raise AssertionError(
            f"Unsupported `controlnet` type: {type(self.controlnet)}. Expected `ControlNetModel` or"
            " `MultiControlNetModel`."
        )

    # Check `image`
    if is_single_controlnet:
        self.check_image(image, prompt, prompt_embeds)
    else:
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")

        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        elif any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        elif len(image) != len(self.controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)

    # Check `controlnet_conditioning_scale`
    if is_single_controlnet:
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif isinstance(controlnet_conditioning_scale, list):
        if any(isinstance(i, list) for i in controlnet_conditioning_scale):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        # This length check used to sit in an `elif` behind the `isinstance(..., list)` test above
        # and therefore never ran; it is now evaluated for every flat list.
        if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
            raise ValueError(
                "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                " the same length as the number of controlnets"
            )

    # Scalars are wrapped so that the checks below can treat both forms uniformly.
    if not isinstance(control_guidance_start, (tuple, list)):
        control_guidance_start = [control_guidance_start]

    if not isinstance(control_guidance_end, (tuple, list)):
        control_guidance_end = [control_guidance_end]

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
            )


# Adapted from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
def check_image(self, image, prompt, prompt_embeds):
    """
    Validate the type of a conditioning image and its batch size against the prompt batch size.

    Args:
        image: a PIL image, numpy array or torch tensor, or a list whose first element is one of
            those.
        prompt (`str` or `list`, *optional*): text prompt(s) used to derive the prompt batch size.
        prompt_embeds (*optional*): used for the batch size (`shape[0]`) when `prompt` is `None`.

    Raises:
        TypeError: if `image` is none of the accepted types (this includes an empty list).
        ValueError: if the image batch size is neither 1 nor equal to the prompt batch size.
    """
    # Only peek at the first element of a non-empty list, so that an empty list is reported through
    # the TypeError below instead of failing with an IndexError.
    first_item = image[0] if isinstance(image, list) and image else None

    image_is_pil = isinstance(image, PIL.Image.Image)
    image_is_tensor = isinstance(image, torch.Tensor)
    image_is_np = isinstance(image, np.ndarray)
    image_is_pil_list = isinstance(first_item, PIL.Image.Image)
    image_is_tensor_list = isinstance(first_item, torch.Tensor)
    image_is_np_list = isinstance(first_item, np.ndarray)

    if (
        not image_is_pil
        and not image_is_tensor
        and not image_is_np
        and not image_is_pil_list
        and not image_is_tensor_list
        and not image_is_np_list
    ):
        raise TypeError(
            f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
        )

    if image_is_pil:
        image_batch_size = 1
    else:
        image_batch_size = len(image)

    if prompt is not None and isinstance(prompt, str):
        prompt_batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]

    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )


# Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
def prepare_control_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """
    Preprocess the ControlNet conditioning image and expand it to the effective batch size.

    Args:
        image: conditioning image(s) accepted by `self.control_image_processor.preprocess`.
        width, height: target size forwarded to the preprocessor.
        batch_size (`int`): repeat count used when the preprocessed batch holds a single image.
        num_images_per_prompt (`int`): repeat count used otherwise.
        device, dtype: placement of the returned tensor.
        do_classifier_free_guidance (`bool`): together with `guess_mode=False`, doubles the batch.
        guess_mode (`bool`): when true the batch is never doubled.

    Returns:
        The conditioning tensor on `device` with `dtype`.
    """
    image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
    image_batch_size = image.shape[0]

    if image_batch_size == 1:
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeat_by = num_images_per_prompt

    image = image.repeat_interleave(repeat_by, dim=0)

    image = image.to(device=device, dtype=dtype)

    if do_classifier_free_guidance and not guess_mode:
        image = torch.cat([image] * 2)

    return image


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
    """
    Select the tail of the scheduler's timesteps that corresponds to `strength`.

    Args:
        num_inference_steps (`int`): number of steps the scheduler was configured with.
        strength (`float`): fraction of the schedule to run.
        device: unused by this method.

    Returns:
        `(timesteps, steps)`: the slice of `self.scheduler.timesteps` starting at
        `t_start * self.scheduler.order` and the number of inference steps left,
        `num_inference_steps - t_start`.
    """
    # get the original timestep using init_timestep
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
    # Schedulers exposing `set_begin_index` are told where the truncated schedule starts.
    if hasattr(self.scheduler, "set_begin_index"):
        self.scheduler.set_begin_index(t_start * self.scheduler.order)

    return timesteps, num_inference_steps - t_start


# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
def prepare_latents(
    self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
):
    """
    Turn the initial image into (optionally noised) latents for the denoising loop.

    Args:
        image: initial image. An input whose `shape[1]` is 4 is treated as latents and is not
            encoded again.
            NOTE(review): PIL images and lists pass the type check but `.to(...)` is then called on
            the value directly, so callers presumably hand in a tensor - confirm against `__call__`.
        timestep: timestep passed to `self.scheduler.add_noise`.
        batch_size (`int`): prompt batch size; multiplied by `num_images_per_prompt` internally.
        num_images_per_prompt (`int`): images generated per prompt.
        dtype, device: placement of the image and the resulting latents.
        generator: a `torch.Generator`, or a list with one generator per effective batch element.
        add_noise (`bool`): when false the scaled latents are returned without noise.

    Returns:
        The latents tensor.

    Raises:
        ValueError: on an unsupported `image` type, a generator list of the wrong length, or an
            image batch that cannot be tiled evenly to the effective batch size.
    """
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    latents_mean = latents_std = None
    if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
        latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
    if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
        latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)

    # Offload text encoder if `enable_model_cpu_offload` was enabled
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        # `text_encoder_2` is an optional component, so it may legitimately be missing.
        if self.text_encoder_2 is not None:
            self.text_encoder_2.to("cpu")
        torch.cuda.empty_cache()

    image = image.to(device=device, dtype=dtype)

    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        init_latents = image

    else:
        # make sure the VAE is in float32 mode, as it overflows in float16
        if self.vae.config.force_upcast:
            image = image.float()
            self.vae.to(dtype=torch.float32)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        elif isinstance(generator, list):
            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
                )

            # Encode one sample at a time so that each one uses its own generator.
            init_latents = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        # Restore the VAE to the working dtype after the float32 encode.
        if self.vae.config.force_upcast:
            self.vae.to(dtype)

        init_latents = init_latents.to(dtype)
        if latents_mean is not None and latents_std is not None:
            latents_mean = latents_mean.to(device=device, dtype=dtype)
            latents_std = latents_std.to(device=device, dtype=dtype)
            init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
        else:
            init_latents = self.vae.config.scaling_factor * init_latents

    if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
        # expand init_latents for batch_size
        additional_image_per_prompt = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
    elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        init_latents = torch.cat([init_latents], dim=0)

    if add_noise:
        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

    latents = init_latents

    return latents


# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
def _get_add_time_ids(
    self,
    original_size,
    crops_coords_top_left,
    target_size,
    aesthetic_score,
    negative_aesthetic_score,
    negative_original_size,
    negative_crops_coords_top_left,
    negative_target_size,
    dtype,
    text_encoder_projection_dim=None,
):
    """
    Build the positive and negative additional time ids fed to the UNet.

    With `self.config.requires_aesthetics_score` the ids are `size + crop + (aesthetic score,)`;
    otherwise they are `size + crop + target size`. The negative ids are built from the `negative_*`
    arguments in both cases.

    Args:
        original_size, crops_coords_top_left, target_size: tuples for the positive ids.
        aesthetic_score, negative_aesthetic_score (`float`): used only with aesthetics scoring.
        negative_original_size, negative_crops_coords_top_left, negative_target_size: tuples for the
            negative ids.
        dtype: dtype of the returned tensors.
        text_encoder_projection_dim (`int`): added to the time-embedding width when checking it
            against the UNet's `add_embedding.linear_1.in_features`.

    Returns:
        `(add_time_ids, add_neg_time_ids)`, each a tensor with a leading batch axis of size 1.

    Raises:
        ValueError: if the resulting embedding width does not match what the UNet expects.
    """
    if self.config.requires_aesthetics_score:
        add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
        add_neg_time_ids = list(
            negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
        )
    else:
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        # Use the negative crop coordinates here as well; the positive ones were used before, which
        # silently ignored `negative_crops_coords_top_left` in this branch.
        add_neg_time_ids = list(negative_original_size + negative_crops_coords_top_left + negative_target_size)

    passed_add_embed_dim = (
        self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
    )
    expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

    # A difference of exactly one time-embedding slot points at a wrong `requires_aesthetics_score`
    # setting, because the two layouts above differ by one element.
    if (
        expected_add_embed_dim > passed_add_embed_dim
        and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
    ):
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
        )
    elif (
        expected_add_embed_dim < passed_add_embed_dim
        and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
    ):
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
        )
    elif expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)

    return add_time_ids, add_neg_time_ids


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
    """
    Move the VAE to float32, keeping selected decoder parts in the original dtype when possible.

    When the first mid-block attention uses `AttnProcessor2_0` or `XFormersAttnProcessor`,
    `post_quant_conv`, `decoder.conv_in` and `decoder.mid_block` are cast back to the dtype the VAE
    had on entry.
    """
    dtype = self.vae.dtype
    self.vae.to(dtype=torch.float32)
    use_torch_2_0_or_xformers = isinstance(
        self.vae.decoder.mid_block.attentions[0].processor,
        (
            AttnProcessor2_0,
            XFormersAttnProcessor,
        ),
    )
    # if xformers or torch_2_0 is used attention block does not need
    # to be in float32 which can save lots of memory
    if use_torch_2_0_or_xformers:
        self.vae.post_quant_conv.to(dtype)
        self.vae.decoder.conv_in.to(dtype)
        self.vae.decoder.mid_block.to(dtype)


@property
def guidance_scale(self):
    """Return the stored `_guidance_scale` value."""
    return self._guidance_scale


@property
def clip_skip(self):
    """Return the stored `_clip_skip` value."""
    return self._clip_skip


# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active, i.e. the stored guidance scale is greater than 1."""
    return self._guidance_scale > 1


@property
def cross_attention_kwargs(self):
    """Cross-attention kwargs stored in `self._cross_attention_kwargs`."""
    return self._cross_attention_kwargs


@property
def num_timesteps(self):
    """Number of timesteps stored in `self._num_timesteps` (set by `__call__`)."""
    return self._num_timesteps


@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    image: PipelineImageInput = None,
    control_image: PipelineImageInput = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    strength: float = 0.8,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
    guess_mode: bool = False,
    control_guidance_start: Union[float, List[float]] = 0.0,
    control_guidance_end: Union[float, List[float]] = 1.0,
    original_size: Tuple[int, int] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    target_size: Tuple[int, int] = None,
    negative_original_size: Optional[Tuple[int, int]] = None,
    negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
    negative_target_size: Optional[Tuple[int, int]] = None,
    aesthetic_score: float = 6.0,
    negative_aesthetic_score: float = 2.5,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`,
                `List[np.ndarray]`, `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or
                `List[List[PIL.Image.Image]]`):
            The initial image will be used as the starting point for the image generation process. Can also accept
            image latents as `image`, if passing latents directly, it will not be encoded again.
        control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`,
                `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.Tensor]]`,
                `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
            The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
            the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
            be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
            and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
            init, images must be passed as a list such that each element of the list can be correctly batched for
            input to a single controlnet.
        height (`int`, *optional*, defaults to the size of control_image):
            The height in pixels of the generated image. Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        width (`int`, *optional*, defaults to the size of control_image):
            The width in pixels of the generated image. Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        strength (`float`, *optional*, defaults to 0.8):
            Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
            starting point and more noise is added the higher the `strength`. The number of denoising steps depends
            on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
            process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
            essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            not greater than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
            denoised latents are returned without decoding, watermarking or post-processing.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):
            The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
            corresponding scale as a list.
        guess_mode (`bool`, *optional*, defaults to `False`):
            In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
            you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
        control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
            The percentage of total steps at which the controlnet starts applying.
        control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
            The percentage of total steps at which the controlnet stops applying.
        original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
            `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
            explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
            `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
            `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            For most cases, `target_size` should be set to the desired height and width of the generated image. If
            not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
            section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            To negatively condition the generation process based on a specific image resolution. Part of SDXL's
            micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
            micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            To negatively condition the generation process based on a target image resolution. It should be as same
            as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        aesthetic_score (`float`, *optional*, defaults to 6.0):
            Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
            Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
            Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
            simulate an aesthetic score of the generated image by influencing the negative text condition.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during inference, with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
        containing the output images.
    """

    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

    # align format for control guidance
    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
        control_guidance_start, control_guidance_end = (
            mult * [control_guidance_start],
            mult * [control_guidance_end],
        )

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        control_image,
        strength,
        num_inference_steps,
        callback_steps,
        negative_prompt,
        negative_prompt_2,
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        controlnet_conditioning_scale,
        control_guidance_start,
        control_guidance_end,
        callback_on_step_end_tensor_inputs,
    )

    # A legacy `callback` supplied without `callback_steps` would otherwise hit `i % None` in the
    # denoising loop; fall back to invoking it on every step.
    if callback is not None and callback_steps is None:
        callback_steps = 1

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
        controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

    global_pool_conditions = (
        controlnet.config.global_pool_conditions
        if isinstance(controlnet, ControlNetModel)
        else controlnet.nets[0].config.global_pool_conditions
    )
    guess_mode = guess_mode or global_pool_conditions

    # 3.1. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt,
        prompt_2,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )

    # 3.2 Encode ip_adapter_image
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 4. Prepare image and controlnet_conditioning_image
    image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

    if isinstance(controlnet, ControlNetModel):
        control_image = self.prepare_control_image(
            image=control_image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=controlnet.dtype,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            guess_mode=guess_mode,
        )
        height, width = control_image.shape[-2:]
    elif isinstance(controlnet, MultiControlNetModel):
        control_images = []

        for control_image_ in control_image:
            control_image_ = self.prepare_control_image(
                image=control_image_,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )

            control_images.append(control_image_)

        control_image = control_images
        height, width = control_image[0].shape[-2:]
    else:
        # An `assert` would be stripped under `python -O` and let execution continue with an
        # unprepared control image, so fail explicitly instead.
        raise ValueError(f"Unsupported controlnet type: {type(controlnet)}")

    # 5. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
    self._num_timesteps = len(timesteps)

    # 6. Prepare latent variables
    if latents is None:
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
            True,
        )

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7.1 Create tensor stating which controlnets to keep
    controlnet_keep = []
    for i in range(len(timesteps)):
        keeps = [
            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

    # 7.2 Prepare added time ids & embeddings
    if isinstance(control_image, list):
        original_size = original_size or control_image[0].shape[-2:]
    else:
        original_size = original_size or control_image.shape[-2:]
    target_size = target_size or (height, width)

    if negative_original_size is None:
        negative_original_size = original_size
    if negative_target_size is None:
        negative_target_size = target_size
    add_text_embeds = pooled_prompt_embeds

    if self.text_encoder_2 is None:
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
    else:
        text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

    add_time_ids, add_neg_time_ids = self._get_add_time_ids(
        original_size,
        crops_coords_top_left,
        target_size,
        aesthetic_score,
        negative_aesthetic_score,
        negative_original_size,
        negative_crops_coords_top_left,
        negative_target_size,
        dtype=prompt_embeds.dtype,
        text_encoder_projection_dim=text_encoder_projection_dim,
    )
    add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)

    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
        add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    add_time_ids = add_time_ids.to(device)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

            # controlnet(s) inference
            if guess_mode and self.do_classifier_free_guidance:
                # Infer ControlNet only for the conditional batch.
                control_model_input = latents
                control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                controlnet_added_cond_kwargs = {
                    "text_embeds": add_text_embeds.chunk(2)[1],
                    "time_ids": add_time_ids.chunk(2)[1],
                }
            else:
                control_model_input = latent_model_input
                controlnet_prompt_embeds = prompt_embeds
                controlnet_added_cond_kwargs = added_cond_kwargs

            if isinstance(controlnet_keep[i], list):
                cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
            else:
                controlnet_cond_scale = controlnet_conditioning_scale
                if isinstance(controlnet_cond_scale, list):
                    controlnet_cond_scale = controlnet_cond_scale[0]
                cond_scale = controlnet_cond_scale * controlnet_keep[i]

            down_block_res_samples, mid_block_res_sample = self.controlnet(
                control_model_input,
                t,
                encoder_hidden_states=controlnet_prompt_embeds,
                controlnet_cond=control_image,
                conditioning_scale=cond_scale,
                guess_mode=guess_mode,
                added_cond_kwargs=controlnet_added_cond_kwargs,
                return_dict=False,
            )

            if guess_mode and self.do_classifier_free_guidance:
                # Inferred ControlNet only for the conditional batch.
                # To apply the output of ControlNet to both the unconditional and conditional batches,
                # add 0 to the unconditional batch to keep it unchanged.
                down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

            if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                added_cond_kwargs["image_embeds"] = image_embeds

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=self.cross_attention_kwargs,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # Read the scale through the property so it stays consistent with
                # `do_classifier_free_guidance`, which is derived from the same stored value.
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                # Tensors are looked up by local variable name, so the local names used in this
                # function are part of the callback contract and must not be renamed.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                negative_pooled_prompt_embeds = callback_outputs.pop(
                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                )
                add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
                add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    # If we do sequential model offloading, let's offload unet and controlnet
    # manually for max memory savings
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.unet.to("cpu")
        self.controlnet.to("cpu")
        torch.cuda.empty_cache()

    if output_type != "latent":
        # make sure the VAE is in float32 mode, as it overflows in float16
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        if needs_upcasting:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

        # unscale/denormalize the latents
        # denormalize with the mean and std if available and not None
        has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
        has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
        if has_latents_mean and has_latents_std:
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            )
            latents_std = (
                torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            )
            latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
        else:
            latents = latents / self.vae.config.scaling_factor

        image = self.vae.decode(latents, return_dict=False)[0]

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # apply watermark if available
        if self.watermark is not None:
            image = self.watermark.apply_watermark(image)

        image = self.image_processor.postprocess(image, output_type=output_type)
    else:
        # Latent output skips decoding, watermarking and post-processing, but still falls through
        # to the shared epilogue so model hooks are freed and `return_dict` is honoured.
        image = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return StableDiffusionXLPipelineOutput(images=image)
================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import warnings from functools import partial from typing import Dict, List, Optional, Union import jax import jax.numpy as jnp import numpy as np from flax.core.frozen_dict import FrozenDict from flax.jax_utils import unreplicate from flax.training.common_utils import shard from PIL import Image from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel from ...schedulers import ( FlaxDDIMScheduler, FlaxDPMSolverMultistepScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring from ..pipeline_flax_utils import FlaxDiffusionPipeline from ..stable_diffusion import FlaxStableDiffusionPipelineOutput from ..stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Set to True to use python for loop instead of jax.fori_loop for easier debugging DEBUG = False EXAMPLE_DOC_STRING = """ Examples: ```py >>> import jax >>> import numpy as np >>> import jax.numpy as jnp >>> from flax.jax_utils import replicate >>> from flax.training.common_utils import shard >>> from diffusers.utils import load_image, make_image_grid >>> from PIL import Image >>> from diffusers import 
FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel >>> def create_key(seed=0): ... return jax.random.PRNGKey(seed) >>> rng = create_key(0) >>> # get canny image >>> canny_image = load_image( ... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg" ... ) >>> prompts = "best quality, extremely detailed" >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality" >>> # load control net and stable diffusion v1-5 >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32 ... ) >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32 ... ) >>> params["controlnet"] = controlnet_params >>> num_samples = jax.device_count() >>> rng = jax.random.split(rng, jax.device_count()) >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples) >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) >>> p_params = replicate(params) >>> prompt_ids = shard(prompt_ids) >>> negative_prompt_ids = shard(negative_prompt_ids) >>> processed_image = shard(processed_image) >>> output = pipe( ... prompt_ids=prompt_ids, ... image=processed_image, ... params=p_params, ... prng_seed=rng, ... num_inference_steps=50, ... neg_prompt_ids=negative_prompt_ids, ... jit=True, ... ).images >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) >>> output_images = make_image_grid(output_images, num_samples // 4, 4) >>> output_images.save("generated_image.png") ``` """ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): r""" Flax-based pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance. 
This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: vae ([`FlaxAutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.FlaxCLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`FlaxUNet2DConditionModel`]): A `FlaxUNet2DConditionModel` to denoise the encoded image latents. controlnet ([`FlaxControlNetModel`]): Provides additional conditioning to the `unet` during the denoising process. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or [`FlaxDPMSolverMultistepScheduler`]. safety_checker ([`FlaxStableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
""" def __init__( self, vae: FlaxAutoencoderKL, text_encoder: FlaxCLIPTextModel, tokenizer: CLIPTokenizer, unet: FlaxUNet2DConditionModel, controlnet: FlaxControlNetModel, scheduler: Union[ FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler ], safety_checker: FlaxStableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, dtype: jnp.dtype = jnp.float32, ): super().__init__() self.dtype = dtype if safety_checker is None: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def prepare_text_inputs(self, prompt: Union[str, List[str]]): if not isinstance(prompt, (str, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") text_input = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np", ) return text_input.input_ids def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]): if not isinstance(image, (Image.Image, list)): raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") if isinstance(image, Image.Image): image = [image] processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) return processed_images def _get_has_nsfw_concepts(self, features, params): has_nsfw_concepts = self.safety_checker(features, params) return has_nsfw_concepts def _run_safety_checker(self, images, safety_model_params, jit=False): # safety_model_params should already be replicated when jit is True pil_images = [Image.fromarray(image) for image in images] features = self.feature_extractor(pil_images, return_tensors="np").pixel_values if jit: features = shard(features) has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) has_nsfw_concepts = unshard(has_nsfw_concepts) safety_model_params = unreplicate(safety_model_params) else: has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) images_was_copied = False for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): if has_nsfw_concept: if not images_was_copied: images_was_copied = True images = images.copy() images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image if 
any(has_nsfw_concepts): warnings.warn( "Potential NSFW content was detected in one or more images. A black image will be returned" " instead. Try again with a different prompt and/or seed." ) return images, has_nsfw_concepts def _generate( self, prompt_ids: jnp.ndarray, image: jnp.ndarray, params: Union[Dict, FrozenDict], prng_seed: jax.Array, num_inference_steps: int, guidance_scale: float, latents: Optional[jnp.ndarray] = None, neg_prompt_ids: Optional[jnp.ndarray] = None, controlnet_conditioning_scale: float = 1.0, ): height, width = image.shape[-2:] if height % 64 != 0 or width % 64 != 0: raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.") # get prompt text embeddings prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` batch_size = prompt_ids.shape[0] max_length = prompt_ids.shape[-1] if neg_prompt_ids is None: uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" ).input_ids else: uncond_input = neg_prompt_ids negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) image = jnp.concatenate([image] * 2) latents_shape = ( batch_size, self.unet.config.in_channels, height // self.vae_scale_factor, width // self.vae_scale_factor, ) if latents is None: latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") def loop_body(step, args): latents, scheduler_state = args # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes latents_input = jnp.concatenate([latents] * 2) t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] timestep = jnp.broadcast_to(t, latents_input.shape[0]) latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) down_block_res_samples, mid_block_res_sample = self.controlnet.apply( {"params": params["controlnet"]}, jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), encoder_hidden_states=context, controlnet_cond=image, conditioning_scale=controlnet_conditioning_scale, return_dict=False, ) # predict the noise residual noise_pred = self.unet.apply( {"params": params["unet"]}, jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), encoder_hidden_states=context, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, ).sample # perform guidance noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents, scheduler_state scheduler_state = self.scheduler.set_timesteps( params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape ) # scale the initial noise by the standard deviation required by the scheduler latents = latents * params["scheduler"].init_noise_sigma if DEBUG: # run with python for loop for i in range(num_inference_steps): latents, scheduler_state = loop_body(i, (latents, scheduler_state)) else: latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) # scale and decode the image latents with vae latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.apply({"params": params["vae"]}, latents, 
method=self.vae.decode).sample image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) return image @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt_ids: jnp.ndarray, image: jnp.ndarray, params: Union[Dict, FrozenDict], prng_seed: jax.Array, num_inference_steps: int = 50, guidance_scale: Union[float, jnp.ndarray] = 7.5, latents: jnp.ndarray = None, neg_prompt_ids: jnp.ndarray = None, controlnet_conditioning_scale: Union[float, jnp.ndarray] = 1.0, return_dict: bool = True, jit: bool = False, ): r""" The call function to the pipeline for generation. Args: prompt_ids (`jnp.ndarray`): The prompt or prompts to guide the image generation. image (`jnp.ndarray`): Array representing the ControlNet input condition to provide guidance to the `unet` for generation. params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights. prng_seed (`jax.Array`): Array containing random number generator key. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. latents (`jnp.ndarray`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents array is generated by sampling using the supplied random `generator`. controlnet_conditioning_scale (`float` or `jnp.ndarray`, *optional*, defaults to 1.0): The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original `unet`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of a plain tuple. jit (`bool`, defaults to `False`): Whether to run `pmap` versions of the generation and safety scoring functions. This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. Examples: Returns: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ height, width = image.shape[-2:] if isinstance(guidance_scale, float): # Convert to a tensor so each device gets a copy. Follow the prompt_ids for # shape information, as they may be sharded (when `jit` is `True`), or not. guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) if len(prompt_ids.shape) > 2: # Assume sharded guidance_scale = guidance_scale[:, None] if isinstance(controlnet_conditioning_scale, float): # Convert to a tensor so each device gets a copy. Follow the prompt_ids for # shape information, as they may be sharded (when `jit` is `True`), or not. 
controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0]) if len(prompt_ids.shape) > 2: # Assume sharded controlnet_conditioning_scale = controlnet_conditioning_scale[:, None] if jit: images = _p_generate( self, prompt_ids, image, params, prng_seed, num_inference_steps, guidance_scale, latents, neg_prompt_ids, controlnet_conditioning_scale, ) else: images = self._generate( prompt_ids, image, params, prng_seed, num_inference_steps, guidance_scale, latents, neg_prompt_ids, controlnet_conditioning_scale, ) if self.safety_checker is not None: safety_params = params["safety_checker"] images_uint8_casted = (images * 255).round().astype("uint8") num_devices, batch_size = images.shape[:2] images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) images = np.array(images) # block images if any(has_nsfw_concept): for i, is_nsfw in enumerate(has_nsfw_concept): if is_nsfw: images[i] = np.asarray(images_uint8_casted[i]) images = images.reshape(num_devices, batch_size, height, width, 3) else: images = np.asarray(images) has_nsfw_concept = False if not return_dict: return (images, has_nsfw_concept) return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) # Static argnums are pipe, num_inference_steps. A change would trigger recompilation. # Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). 
@partial(
    jax.pmap,
    in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0),
    static_broadcasted_argnums=(0, 5),
)
def _p_generate(
    pipe,
    prompt_ids,
    image,
    params,
    prng_seed,
    num_inference_steps,
    guidance_scale,
    latents,
    neg_prompt_ids,
    controlnet_conditioning_scale,
):
    """`jax.pmap`-wrapped forwarder to `pipe._generate`.

    Arguments 0 (`pipe`) and 5 (`num_inference_steps`) are declared static/broadcast
    (`in_axes=None`, `static_broadcasted_argnums=(0, 5)`); every other argument is mapped
    over its leading axis (`in_axes=0`).

    All arguments are passed through to `pipe._generate` unchanged and in the same order,
    and its result is returned as-is.
    """
    return pipe._generate(
        prompt_ids,
        image,
        params,
        prng_seed,
        num_inference_steps,
        guidance_scale,
        latents,
        neg_prompt_ids,
        controlnet_conditioning_scale,
    )


@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_has_nsfw_concepts(pipe, features, params):
    """`jax.pmap`-wrapped forwarder to `pipe._get_has_nsfw_concepts`.

    `pipe` (argument 0) is static/broadcast; `features` and `params` use the default
    `pmap` mapping over their leading axis.
    """
    return pipe._get_has_nsfw_concepts(features, params)


def unshard(x: jnp.ndarray):
    """Merge the two leading axes of `x` into a single axis.

    The first two dimensions are read as `(num_devices, batch_size)` and collapsed into
    one dimension of size `num_devices * batch_size`; all trailing dimensions are kept.
    """
    # einops.rearrange(x, 'd b ... -> (d b) ...')
    num_devices, batch_size = x.shape[:2]
    rest = x.shape[2:]
    return x.reshape(num_devices * batch_size, *rest)


def preprocess(image, dtype):
    """Convert an image into a batched, channel-first `jnp` array.

    Args:
        image: Object exposing `convert`, `size` and `resize` (presumably a `PIL.Image.Image`;
            verify against callers).
        dtype: dtype the pixel array is cast to before scaling.

    Returns:
        The image converted to RGB, resized so that width and height are rounded *down* to
        multiples of 64, divided by 255.0, given a leading axis of length 1 and transposed
        with `(0, 3, 1, 2)`.
    """
    # NOTE(review): a side shorter than 64 pixels rounds down to 0 here - confirm callers never pass such images.
    image = image.convert("RGB")
    w, h = image.size
    w, h = (x - x % 64 for x in (w, h))  # resize to integer multiple of 64
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = jnp.array(image).astype(dtype) / 255.0
    # Add a leading axis, then move the last axis to position 1.
    image = image[None].transpose(0, 3, 1, 2)
    return image


================================================
FILE: src/diffusers/pipelines/controlnet_hunyuandit/__init__.py
================================================
# Lazy-import shim for the `controlnet_hunyuandit` pipeline package.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Placeholder objects collected when torch/transformers are unavailable.
_dummy_objects = {}
# Maps submodule name -> list of public names it provides; handed to `_LazyModule` below.
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_hunyuandit_controlnet"] = ["HunyuanDiTControlNetPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: import the real pipeline (or the dummy stand-ins) immediately.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_hunyuandit_controlnet import HunyuanDiTControlNetPipeline

else:
    import sys

    # Lazy path: replace this module in `sys.modules` with a `_LazyModule` built from `_import_structure`.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach any collected dummy objects to the replacement module.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)


================================================
FILE: src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
================================================
# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...models import AutoencoderKL, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel from ...models.embeddings import get_2d_rotary_pos_embed from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import DDPMScheduler from ...utils import ( is_torch_xla_available, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py from diffusers import HunyuanDiT2DControlNetModel, HunyuanDiTControlNetPipeline import torch controlnet = HunyuanDiT2DControlNetModel.from_pretrained( "Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny", torch_dtype=torch.float16 ) pipe = HunyuanDiTControlNetPipeline.from_pretrained( "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16 ) pipe.to("cuda") from diffusers.utils import load_image cond_image = load_image( "https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny/resolve/main/canny.jpg?download=true" ) ## You may also use English prompt as HunyuanDiT supports both English and Chinese prompt = "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围" # prompt="At night, an ancient Chinese-style lion statue stands in front 
of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere." image = pipe( prompt, height=1024, width=1024, control_image=cond_image, num_inference_steps=50, ).images[0] ``` """ STANDARD_RATIO = np.array( [ 1.0, # 1:1 4.0 / 3.0, # 4:3 3.0 / 4.0, # 3:4 16.0 / 9.0, # 16:9 9.0 / 16.0, # 9:16 ] ) STANDARD_SHAPE = [ [(1024, 1024), (1280, 1280)], # 1:1 [(1024, 768), (1152, 864), (1280, 960)], # 4:3 [(768, 1024), (864, 1152), (960, 1280)], # 3:4 [(1280, 768)], # 16:9 [(768, 1280)], # 9:16 ] STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] SUPPORTED_SHAPE = [ (1024, 1024), (1280, 1280), # 1:1 (1024, 768), (1152, 864), (1280, 960), # 4:3 (768, 1024), (864, 1152), (960, 1280), # 3:4 (1280, 768), # 16:9 (768, 1280), # 9:16 ] def map_to_standard_shapes(target_width, target_height): target_ratio = target_width / target_height closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] return width, height def get_resize_crop_region_for_grid(src, tgt_size): th = tw = tgt_size h, w = src r = h / w # resize if r > 1: resize_height = th resize_width = int(round(th / h * w)) else: resize_width = tw resize_height = int(round(tw / w * h)) crop_top = int(round((th - resize_height) / 2.0)) crop_left = int(round((tw - resize_width) / 2.0)) return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. 
Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg class HunyuanDiTControlNetPipeline(DiffusionPipeline): r""" Pipeline for English/Chinese-to-image generation using HunyuanDiT. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by ourselves) Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use `sdxl-vae-fp16-fix`. text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). HunyuanDiT uses a fine-tuned [bilingual CLIP]. tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): A `BertTokenizer` or `CLIPTokenizer` to tokenize text. transformer ([`HunyuanDiT2DModel`]): The HunyuanDiT model designed by Tencent Hunyuan. text_encoder_2 (`T5EncoderModel`): The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. tokenizer_2 (`MT5Tokenizer`): The tokenizer for the mT5 embedder. scheduler ([`DDPMScheduler`]): A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. 
controlnet ([`HunyuanDiT2DControlNetModel`] or `List[HunyuanDiT2DControlNetModel]` or [`HunyuanDiT2DControlNetModel`]): Provides additional conditioning to the `unet` during the denoising process. If you set multiple ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" _optional_components = [ "safety_checker", "feature_extractor", "text_encoder_2", "tokenizer_2", "text_encoder", "tokenizer", ] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = [ "latents", "prompt_embeds", "negative_prompt_embeds", "prompt_embeds_2", "negative_prompt_embeds_2", ] def __init__( self, vae: AutoencoderKL, text_encoder: BertModel, tokenizer: BertTokenizer, transformer: HunyuanDiT2DModel, scheduler: DDPMScheduler, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, controlnet: Union[ HunyuanDiT2DControlNetModel, List[HunyuanDiT2DControlNetModel], Tuple[HunyuanDiT2DControlNetModel], HunyuanDiT2DMultiControlNetModel, ], text_encoder_2=T5EncoderModel, tokenizer_2=MT5Tokenizer, requires_safety_checker: bool = True, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, tokenizer_2=tokenizer_2, transformer=transformer, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, text_encoder_2=text_encoder_2, controlnet=controlnet, ) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. 
Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) self.default_sample_size = ( self.transformer.config.sample_size if hasattr(self, "transformer") and self.transformer is not None else 128 ) # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt def encode_prompt( self, prompt: str, device: torch.device = None, dtype: torch.dtype = None, num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, negative_prompt: Optional[str] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, max_sequence_length: Optional[int] = None, text_encoder_index: int = 0, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device dtype (`torch.dtype`): torch dtype num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. prompt_attention_mask (`torch.Tensor`, *optional*): Attention mask for the prompt. Required when `prompt_embeds` is passed directly. negative_prompt_attention_mask (`torch.Tensor`, *optional*): Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. text_encoder_index (`int`, *optional*): Index of the text encoder to use. `0` for clip and `1` for T5. 
""" if dtype is None: if self.text_encoder_2 is not None: dtype = self.text_encoder_2.dtype elif self.transformer is not None: dtype = self.transformer.dtype else: dtype = None if device is None: device = self._execution_device tokenizers = [self.tokenizer, self.tokenizer_2] text_encoders = [self.text_encoder, self.text_encoder_2] tokenizer = tokenizers[text_encoder_index] text_encoder = text_encoders[text_encoder_index] if max_sequence_length is None: if text_encoder_index == 0: max_length = 77 if text_encoder_index == 1: max_length = 256 else: max_length = max_sequence_length if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: text_inputs = tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = text_encoder( text_input_ids.to(device), attention_mask=prompt_attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = 
prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = prompt_embeds.shape[1] uncond_input = tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    prompt_attention_mask=None,
    negative_prompt_attention_mask=None,
    prompt_embeds_2=None,
    negative_prompt_embeds_2=None,
    prompt_attention_mask_2=None,
    negative_prompt_attention_mask_2=None,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the user-facing arguments of the pipeline call.

    Nothing is returned; the first violated rule raises ``ValueError``.
    Checks run in this order: spatial size divisible by 8, callback tensor
    names known to ``self._callback_tensor_inputs``, prompt / embedding
    exclusivity and presence, attention masks accompanying every
    pre-computed embedding, and matching shapes between positive and
    negative embeddings.
    """
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and any(
        k not in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is None:
        # Without a text prompt both encoders need pre-computed embeddings.
        if prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        if prompt_embeds_2 is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
            )
    else:
        if prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if prompt_embeds is not None and prompt_attention_mask is None:
        raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

    if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
        raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
        raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

    if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
        raise ValueError(
            "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
        )

    # Shapes are only compared when both sides were supplied directly.
    if (
        prompt_embeds is not None
        and negative_prompt_embeds is not None
        and prompt_embeds.shape != negative_prompt_embeds.shape
    ):
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    if (
        prompt_embeds_2 is not None
        and negative_prompt_embeds_2 is not None
        and prompt_embeds_2.shape != negative_prompt_embeds_2.shape
    ):
        raise ValueError(
            "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
            f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
            f" {negative_prompt_embeds_2.shape}."
        )
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 5.0,
    control_image: PipelineImageInput = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_2: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds_2: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    prompt_attention_mask_2: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    # NOTE(review): mutable list default; within this method it is only read or rebound, never mutated.
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    guidance_rescale: float = 0.0,
    original_size: Optional[Tuple[int, int]] = (1024, 1024),
    target_size: Optional[Tuple[int, int]] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    use_resolution_binning: bool = True,
):
    r"""
    The call function to the pipeline for generation with HunyuanDiT and ControlNet conditioning.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`
            and `prompt_embeds_2`.
        height (`int`, *optional*):
            The height in pixels of the generated image. Defaults to
            `self.default_sample_size * self.vae_scale_factor` and is rounded down to a multiple of 16.
        width (`int`, *optional*):
            The width in pixels of the generated image. Defaults to
            `self.default_sample_size * self.vae_scale_factor` and is rounded down to a multiple of 16.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`,
            `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or
            `List[List[PIL.Image.Image]]`):
            The ControlNet input condition used to guide generation. A `torch.Tensor` is used as is; any other
            input is preprocessed to `height` x `width`. With a single ControlNet the output size is taken from
            the prepared control image. With a multi-ControlNet model, pass a list with one entry per ControlNet.
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            Scale forwarded to the ControlNet as `conditioning_scale`. If multiple ControlNets are used, you can
            set the corresponding scale as a list.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated initial latents. If not provided, they are sampled with `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings for the first text encoder. If not provided, text embeddings are
            generated from the `prompt` input argument.
        prompt_embeds_2 (`torch.Tensor`, *optional*):
            Pre-generated text embeddings for the second text encoder. If not provided, text embeddings are
            generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings for the first text encoder. If not provided, they are
            generated from the `negative_prompt` input argument.
        negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings for the second text encoder. If not provided, they are
            generated from the `negative_prompt` input argument.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
        prompt_attention_mask_2 (`torch.Tensor`, *optional*):
            Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
        negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
            Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. Passing
            `"latent"` skips VAE decoding and the safety checker.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`,
            *optional*):
            A callback function or a list of callback functions to be called at the end of each denoising step.
        callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to `["latents"]`):
            Names of the tensors passed to the callback. Overridden by `callback_on_step_end.tensor_inputs` when
            a `PipelineCallback` or `MultiPipelineCallbacks` is given.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
            Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
        original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
            The original size of the image. Used to calculate the time ids.
        target_size (`Tuple[int, int]`, *optional*):
            The target size of the image. Used to calculate the time ids. Defaults to `(height, width)`.
        crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
            The top left coordinates of the crop. Used to calculate the time ids.
        use_resolution_binning (`bool`, *optional*, defaults to `True`):
            Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the
            closest standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864,
            1280x960, 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to
            `True`.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """
    # Callback objects carry their own list of requested tensors; it takes precedence over the argument.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 0. default height and width
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor
    # round down to a multiple of 16
    height = int((height // 16) * 16)
    width = int((width // 16) * 16)

    if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE:
        # note the (width, height) argument and return order of the helper
        width, height = map_to_standard_shapes(width, height)
        height = int(height)
        width = int(width)
        logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}")

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
        prompt_embeds_2,
        negative_prompt_embeds_2,
        prompt_attention_mask_2,
        negative_prompt_attention_mask_2,
        callback_on_step_end_tensor_inputs,
    )
    # These back the `guidance_scale`, `guidance_rescale`, `do_classifier_free_guidance` and `interrupt` properties.
    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    # First text encoder (index 0), sequences capped at 77 tokens.
    (
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt=prompt,
        device=device,
        dtype=self.transformer.dtype,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        max_sequence_length=77,
        text_encoder_index=0,
    )
    # Second text encoder (index 1), sequences capped at 256 tokens; its outputs feed the `*_t5` arguments below.
    (
        prompt_embeds_2,
        negative_prompt_embeds_2,
        prompt_attention_mask_2,
        negative_prompt_attention_mask_2,
    ) = self.encode_prompt(
        prompt=prompt,
        device=device,
        dtype=self.transformer.dtype,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds_2,
        negative_prompt_embeds=negative_prompt_embeds_2,
        prompt_attention_mask=prompt_attention_mask_2,
        negative_prompt_attention_mask=negative_prompt_attention_mask_2,
        max_sequence_length=256,
        text_encoder_index=1,
    )

    # 4. Prepare control image
    if isinstance(self.controlnet, HunyuanDiT2DControlNetModel):
        control_image = self.prepare_image(
            image=control_image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=self.dtype,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            guess_mode=False,
        )
        # the prepared control image decides the final output size
        height, width = control_image.shape[-2:]

        # the ControlNet is conditioned on the VAE latents of the control image, not on pixels
        control_image = self.vae.encode(control_image).latent_dist.sample()
        control_image = control_image * self.vae.config.scaling_factor

    elif isinstance(self.controlnet, HunyuanDiT2DMultiControlNetModel):
        control_images = []

        for control_image_ in control_image:
            control_image_ = self.prepare_image(
                image=control_image_,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=self.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=False,
            )

            control_image_ = self.vae.encode(control_image_).latent_dist.sample()
            control_image_ = control_image_ * self.vae.config.scaling_factor

            control_images.append(control_image_)

        # NOTE(review): unlike the single-ControlNet branch, height/width are not re-derived from the images here.
        control_image = control_images
    else:
        # NOTE(review): unsupported controlnet types only fail via assert, which is stripped under `python -O`.
        assert False

    # 5. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 6. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 8. create image_rotary_emb, style embedding & time ids
    # grid size = pixel size / 8 / patch size (the 8 is presumably the VAE downscale factor; confirm)
    grid_height = height // 8 // self.transformer.config.patch_size
    grid_width = width // 8 // self.transformer.config.patch_size
    base_size = 512 // 8 // self.transformer.config.patch_size
    grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
    image_rotary_emb = get_2d_rotary_pos_embed(
        self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
    )

    style = torch.tensor([0], device=device)

    target_size = target_size or (height, width)
    # tuple concatenation: (orig_h, orig_w, target_h, target_w, crop_top, crop_left)
    add_time_ids = list(original_size + target_size + crops_coords_top_left)
    add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)

    if self.do_classifier_free_guidance:
        # negative (unconditional) entries first, matching the `chunk(2)` order in the denoising loop
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
        prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
        prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
        add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
        style = torch.cat([style] * 2, dim=0)

    prompt_embeds = prompt_embeds.to(device=device)
    prompt_attention_mask = prompt_attention_mask.to(device=device)
    prompt_embeds_2 = prompt_embeds_2.to(device=device)
    prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
    add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
        batch_size * num_images_per_prompt, 1
    )
    style = style.to(device=device).repeat(batch_size * num_images_per_prompt)

    # 9. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # once interrupted, the remaining timesteps are iterated but skipped (no model call, no callback)
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
            t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
                dtype=latent_model_input.dtype
            )

            # controlnet(s) inference
            control_block_samples = self.controlnet(
                latent_model_input,
                t_expand,
                encoder_hidden_states=prompt_embeds,
                text_embedding_mask=prompt_attention_mask,
                encoder_hidden_states_t5=prompt_embeds_2,
                text_embedding_mask_t5=prompt_attention_mask_2,
                image_meta_size=add_time_ids,
                style=style,
                image_rotary_emb=image_rotary_emb,
                return_dict=False,
                controlnet_cond=control_image,
                conditioning_scale=controlnet_conditioning_scale,
            )[0]

            # predict the noise residual
            noise_pred = self.transformer(
                latent_model_input,
                t_expand,
                encoder_hidden_states=prompt_embeds,
                text_embedding_mask=prompt_attention_mask,
                encoder_hidden_states_t5=prompt_embeds_2,
                text_embedding_mask_t5=prompt_attention_mask_2,
                image_meta_size=add_time_ids,
                style=style,
                image_rotary_emb=image_rotary_emb,
                return_dict=False,
                controlnet_block_samples=control_block_samples,
            )[0]

            # keep the first half along dim 1 and discard the rest
            # (presumably the model also predicts a variance term; TODO confirm)
            noise_pred, _ = noise_pred.chunk(2, dim=1)

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if self.do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                # requested tensors are looked up by name among this function's local variables
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # the callback may hand back replacements for any of these tensors
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
                negative_prompt_embeds_2 = callback_outputs.pop(
                    "negative_prompt_embeds_2", negative_prompt_embeds_2
                )

            # advance the bar on the last step, or after warmup on every `scheduler.order`-th step
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # images flagged by the safety checker are not denormalized
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
_LazyModule, get_objects_from_module, is_flax_available, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_stable_diffusion_3_controlnet"] = ["StableDiffusion3ControlNetPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_stable_diffusion_3_controlnet import StableDiffusion3ControlNetPipeline try: if not (is_transformers_available() and is_flax_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py ================================================ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import ( CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, ) from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusion3ControlNetPipeline >>> from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel >>> from diffusers.utils import load_image >>> controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16) >>> pipe = StableDiffusion3ControlNetPipeline.from_pretrained( ... 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of `num_inference_steps`, `timesteps` or `sigmas` drives the schedule. Extra keyword arguments are
    forwarded to `scheduler.set_timesteps` unchanged.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps; used only when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Cannot be combined with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Cannot be combined with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's timestep schedule and the number of inference steps (the
        schedule length when a custom schedule was supplied, otherwise `num_inference_steps` as passed in).

    Raises:
        `ValueError`: If both `timesteps` and `sigmas` are given, or the scheduler's `set_timesteps` does not
        accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own spacing.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: the scheduler must explicitly accept the keyword.
    accepted = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
text_encoder ([`CLIPTextModelWithProjection`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, with an added projection layer that is initialized with a diagonal matrix with the `hidden_size` as its dimension. text_encoder_2 ([`CLIPTextModelWithProjection`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. text_encoder_3 ([`T5EncoderModel`]): Frozen text-encoder. Stable Diffusion 3 uses [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). controlnet ([`SD3ControlNetModel`] or `List[SD3ControlNetModel]` or [`SD3MultiControlNetModel`]): Provides additional conditioning to the `transformer` during the denoising process. If you set multiple ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. 
def __init__(
    self,
    transformer: SD3Transformer2DModel,
    scheduler: FlowMatchEulerDiscreteScheduler,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer_2: CLIPTokenizer,
    text_encoder_3: T5EncoderModel,
    tokenizer_3: T5TokenizerFast,
    controlnet: Union[
        SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
    ],
):
    """Register all pipeline components and derive the size defaults.

    After registration the following attributes are set, each with a
    fallback for when the corresponding component is missing or ``None``:
    ``vae_scale_factor`` (fallback 8), ``tokenizer_max_length``
    (fallback 77) and ``default_sample_size`` (fallback 128).
    ``image_processor`` is built from ``vae_scale_factor``.
    """
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        text_encoder_3=text_encoder_3,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        tokenizer_3=tokenizer_3,
        transformer=transformer,
        scheduler=scheduler,
        controlnet=controlnet,
    )

    # Spatial downscale factor of the VAE: one halving per block after the first.
    registered_vae = getattr(self, "vae", None)
    if registered_vae is not None:
        self.vae_scale_factor = 2 ** (len(registered_vae.config.block_out_channels) - 1)
    else:
        self.vae_scale_factor = 8
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    registered_tokenizer = getattr(self, "tokenizer", None)
    if registered_tokenizer is not None:
        self.tokenizer_max_length = registered_tokenizer.model_max_length
    else:
        self.tokenizer_max_length = 77

    registered_transformer = getattr(self, "transformer", None)
    if registered_transformer is not None:
        self.default_sample_size = registered_transformer.config.sample_size
    else:
        self.default_sample_size = 128
torch.zeros( ( batch_size * num_images_per_prompt, self.tokenizer_max_length, self.transformer.config.joint_attention_dim, ), device=device, dtype=dtype, ) text_inputs = self.tokenizer_3( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] dtype = self.text_encoder_3.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds def _get_clip_prompt_embeds( self, prompt: Union[str, List[str]], num_images_per_prompt: int = 1, device: Optional[torch.device] = None, clip_skip: Optional[int] = None, clip_model_index: int = 0, ): device = device or self._execution_device clip_tokenizers = [self.tokenizer, self.tokenizer_2] clip_text_encoders = [self.text_encoder, self.text_encoder_2] tokenizer = clip_tokenizers[clip_model_index] text_encoder = clip_text_encoders[clip_model_index] prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = tokenizer( prompt, 
padding="max_length", max_length=self.tokenizer_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], prompt_3: Union[str, List[str]], device: Optional[torch.device] = None, num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt_3: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.FloatTensor] = 
None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, clip_skip: Optional[int] = None, max_sequence_length: int = 256, lora_scale: Optional[float] = None, ): r""" Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in all text-encoders prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is used in all text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 prompt_3 = prompt_3 or prompt prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=clip_skip, clip_model_index=0, ) prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( prompt=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=clip_skip, clip_model_index=1, ) clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) t5_prompt_embed = self._get_t5_prompt_embeds( prompt=prompt_3, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) clip_prompt_embeds = torch.nn.functional.pad( clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) ) prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt negative_prompt_3 = negative_prompt_3 or negative_prompt # normalize str to list 
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) negative_prompt_3 = ( batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 ) if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( negative_prompt, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=None, clip_model_index=0, ) negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( negative_prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=None, clip_model_index=1, ) negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) t5_negative_prompt_embed = self._get_t5_prompt_embeds( prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) negative_clip_prompt_embeds = torch.nn.functional.pad( negative_clip_prompt_embeds, (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), ) negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) negative_pooled_prompt_embeds = torch.cat( [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 ) if self.text_encoder is not None: if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: # 
Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds def check_inputs( self, prompt, prompt_2, prompt_3, height, width, negative_prompt=None, negative_prompt_2=None, negative_prompt_3=None, prompt_embeds=None, negative_prompt_embeds=None, pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt_3 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) elif negative_prompt_2 is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) elif negative_prompt_3 is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. 
Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, ): if latents is not None: return latents.to(device=device, dtype=dtype) shape = ( batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents def prepare_image( self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance=False, guess_mode=False, ): if isinstance(image, torch.Tensor): pass else: image = self.image_processor.preprocess(image, height=height, width=width) image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size else: # image batch size is the same as prompt batch size repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) if do_classifier_free_guidance and not guess_mode: image = torch.cat([image] * 2) return image @property def guidance_scale(self): return self._guidance_scale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
@property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @property def joint_attention_kwargs(self): return self._joint_attention_kwargs @property def num_timesteps(self): return self._num_timesteps @property def interrupt(self): return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, prompt_3: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 28, timesteps: List[int] = None, guidance_scale: float = 7.0, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, control_image: PipelineImageInput = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, controlnet_pooled_projections: Optional[torch.FloatTensor] = None, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt_3: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. 
If not defined, one has to pass `prompt_embeds`. instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is will be used instead prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is will be used instead height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the ControlNet starts applying. 
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the ControlNet stops applying. control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, images must be passed as a list such that each element of the list can be correctly batched for input to a single ControlNet. controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set the corresponding scale as a list. controlnet_pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected from the embeddings of controlnet input conditions. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. 
If not defined, `negative_prompt` is used instead negative_prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `negative_prompt` is used instead num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. 
output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. Examples: Returns: [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. 
""" height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = len(self.controlnet.nets) if isinstance(self.controlnet, SD3MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], ) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, prompt_2, prompt_3, height, width, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, negative_prompt_3=negative_prompt_3, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._joint_attention_kwargs = joint_attention_kwargs self._interrupt = False # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device dtype = self.transformer.dtype ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, prompt_3=prompt_3, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, negative_prompt_3=negative_prompt_3, do_classifier_free_guidance=self.do_classifier_free_guidance, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, device=device, clip_skip=self.clip_skip, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) # 3. 
Prepare control image if isinstance(self.controlnet, SD3ControlNetModel): control_image = self.prepare_image( image=control_image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=False, ) height, width = control_image.shape[-2:] control_image = self.vae.encode(control_image).latent_dist.sample() control_image = control_image * self.vae.config.scaling_factor elif isinstance(self.controlnet, SD3MultiControlNetModel): control_images = [] for control_image_ in control_image: control_image_ = self.prepare_image( image=control_image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=False, ) control_image_ = self.vae.encode(control_image_).latent_dist.sample() control_image_ = control_image_ * self.vae.config.scaling_factor control_images.append(control_image_) control_image = control_images else: assert False if controlnet_pooled_projections is None: controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds) else: controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. 
Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps) # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] # controlnet(s) inference control_block_samples = self.controlnet( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, pooled_projections=controlnet_pooled_projections, joint_attention_kwargs=self.joint_attention_kwargs, controlnet_cond=control_image, conditioning_scale=cond_scale, return_dict=False, )[0] noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds, block_controlnet_hidden_states=control_block_samples, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 
latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: xm.mark_step() if output_type == "latent": image = latents else: latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return StableDiffusion3PipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/controlnet_xs/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_flax_available, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except 
# Package initializer for the ControlNet-XS pipelines.
#
# The module either imports the pipeline classes eagerly (type checking / slow-import mode)
# or replaces itself in `sys.modules` with a `_LazyModule` built from `_import_structure`.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_flax_available,
    is_torch_available,
    is_transformers_available,
)


# Objects collected from the `dummy_*` modules when an optional backend is missing;
# they are attached to the lazy module at the bottom of this file.
_dummy_objects = {}
# Maps submodule name -> list of public names; handed to `_LazyModule` below.
_import_structure = {}

# torch + transformers backend: register the real pipelines, or fall back to the dummy objects.
try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"]
    _import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"]
# flax + transformers backend: nothing is registered on success (the `else` branch is a no-op),
# only the dummy objects are collected when the backend is unavailable.
try:
    if not (is_transformers_available() and is_flax_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_flax_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
    pass  # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: import the concrete classes (or the dummy replacements) right away.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
        from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline

    try:
        if not (is_transformers_available() and is_flax_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_flax_and_transformers_objects import *  # noqa F403
    else:
        pass  # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline


else:
    # Lazy path: swap this module for a `_LazyModule` driven by `_import_structure`.
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    # Expose the collected dummy objects as attributes of the lazy module.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install opencv-python transformers accelerate >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAdapter >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> 
import cv2 >>> from PIL import Image >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" >>> negative_prompt = "low quality, bad quality, sketches" >>> # download an image >>> image = load_image( ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" ... ) >>> # initialize the models and pipeline >>> controlnet_conditioning_scale = 0.5 >>> controlnet = ControlNetXSAdapter.from_pretrained( ... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 ... ) >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( ... "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16 ... ) >>> pipe.enable_model_cpu_offload() >>> # get canny image >>> image = np.array(image) >>> image = cv2.Canny(image, 100, 200) >>> image = image[:, :, None] >>> image = np.concatenate([image, image, image], axis=2) >>> canny_image = Image.fromarray(image) >>> # generate image >>> image = pipe( ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image ... ).images[0] ``` """ class StableDiffusionControlNetXSPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, FromSingleFileMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. controlnet ([`ControlNetXSAdapter`]): A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
""" model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: Union[UNet2DConditionModel, UNetControlNetXSModel], controlnet: ControlNetXSAdapter, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if isinstance(unet, UNet2DConditionModel): unet = UNetControlNetXSModel.from_unet(unet, controlnet) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
    controlnet: ControlNetXSAdapter,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and build the image processors.

    A plain `UNet2DConditionModel` is combined with `controlnet` through
    `UNetControlNetXSModel.from_unet`; a `UNetControlNetXSModel` is registered as given.

    Args:
        vae: Autoencoder; its `config.block_out_channels` determines `vae_scale_factor`.
        text_encoder: Text encoder registered on the pipeline.
        tokenizer: Tokenizer registered on the pipeline.
        unet: Either a `UNet2DConditionModel` (wrapped together with `controlnet`) or an already
            fused `UNetControlNetXSModel`.
        controlnet: The ControlNet-XS adapter.
        scheduler: Scheduler registered on the pipeline.
        safety_checker: Optional safety checker; `None` disables it.
        feature_extractor: Required whenever `safety_checker` is not `None`.
        requires_safety_checker: Stored in the pipeline config; when `True` and `safety_checker` is `None` a
            warning is logged.

    Raises:
        ValueError: If a `safety_checker` is passed without a `feature_extractor`.
    """
    super().__init__()

    if isinstance(unet, UNet2DConditionModel):
        unet = UNetControlNetXSModel.from_unet(unet, controlnet)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string, otherwise `{self.__class__}` is emitted verbatim.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    # The control image goes through a separate processor with `do_normalize=False`.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    **kwargs,
):
    """
    Deprecated wrapper around `encode_prompt()`.

    Emits a deprecation notice, forwards every argument to `encode_prompt()` and returns the two
    results concatenated into one tensor, negative embeddings first and prompt embeddings second.
    """
    deprecate(
        "_encode_prompt()",
        "1.0.0",
        "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.",
        standard_warn=False,
    )

    encoded = self.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
        **kwargs,
    )

    # `encode_prompt` returns (prompt_embeds, negative_prompt_embeds); the legacy layout is the reverse order.
    conditional_embeds = encoded[0]
    unconditional_embeds = encoded[1]
    return torch.cat([unconditional_embeds, conditional_embeds])
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`. `negative_prompt_embeds` is returned exactly as it was
        passed in (possibly `None`) when `do_classifier_free_guidance` is `False`.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # Batch size comes from the prompt when one is given, otherwise from the pre-computed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Second, untruncated tokenization is only used to detect and report what truncation removed.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # Only pass an attention mask when the text encoder config opts in via `use_attention_mask`.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Target dtype preference: text encoder, then unet, then whatever the embeddings already use.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            # No negative prompt: use one empty string per batch entry.
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad/truncate the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker on `image`.

    Returns:
        `(image, has_nsfw_concept)`. Without a safety checker the image is returned untouched and the flag is
        `None`; otherwise both values are whatever `self.safety_checker` returns.
    """
    checker = self.safety_checker
    if checker is None:
        return image, None

    # The feature extractor is fed PIL images, so convert from tensor or numpy first.
    to_pil = self.image_processor
    pil_images = (
        to_pil.postprocess(image, output_type="pil") if torch.is_tensor(image) else to_pil.numpy_to_pil(image)
    )
    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)

    checked_image, nsfw_flags = checker(images=image, clip_input=clip_features.pixel_values.to(dtype))
    return checked_image, nsfw_flags
def decode_latents(self, latents):
    """
    Deprecated: decode `latents` with the VAE and return the images as a float32 numpy array.

    The decoded tensor is mapped with `x / 2 + 0.5`, clamped to `[0, 1]` and permuted with `(0, 2, 3, 1)`
    before the numpy conversion.
    """
    deprecate(
        "decode_latents",
        "1.0.0",
        "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead",
        standard_warn=False,
    )

    inverse_scale = 1 / self.vae.config.scaling_factor
    decoded = self.vae.decode(inverse_scale * latents, return_dict=False)[0]
    unit_range = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    on_cpu = unit_range.cpu().permute(0, 2, 3, 1)
    return on_cpu.float().numpy()
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only forwarded when the
    scheduler's `step` declares a parameter of that name.

    eta (η) is only used with the DDIMScheduler and is ignored for other schedulers. It corresponds to η in the
    DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].
    """
    accepted = inspect.signature(self.scheduler.step).parameters
    optional_args = (("eta", eta), ("generator", generator))
    return {key: value for key, value in optional_args if key in accepted}
def check_inputs(
    self,
    prompt,
    image,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the arguments of a pipeline call and raise on the first inconsistency.

    Checks, in order: the callback tensor names, the `prompt` / `prompt_embeds` combination, the
    `negative_prompt` / `negative_prompt_embeds` combination, matching embedding shapes, the control image and
    conditioning scale, and finally the control guidance window.

    Raises:
        ValueError: For invalid callback tensor names, conflicting or missing prompt inputs, mismatching
            embedding shapes, or a guidance window that is empty or outside `[0, 1]`.
        TypeError: If `controlnet_conditioning_scale` is not a `float` (or from `check_image`).
        AssertionError: If `self.unet` is not a `UNetControlNetXSModel` (optionally compiled).
    """
    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # Check `image` and `controlnet_conditioning_scale`
    # A compiled unet is an `OptimizedModule` that exposes the wrapped model as `_orig_mod`.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.unet, torch._dynamo.eval_frame.OptimizedModule
    )
    if isinstance(self.unet, UNetControlNetXSModel) or (
        is_compiled and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
    ):
        self.check_image(image, prompt, prompt_embeds)
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    else:
        # Raised explicitly: a bare `assert False` is removed under `python -O`, which would let an
        # unsupported unet skip the image / conditioning-scale validation silently.
        raise AssertionError(
            f"`unet` must be a `UNetControlNetXSModel` (optionally wrapped by `torch.compile`), but is"
            f" {type(self.unet)}."
        )

    start, end = control_guidance_start, control_guidance_end
    if start >= end:
        raise ValueError(
            f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
        )
    if start < 0.0:
        raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
    if end > 1.0:
        raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
def check_image(self, image, prompt, prompt_embeds):
    """
    Validate the type of the control `image` and its batch size against the prompt batch size.

    Args:
        image: A PIL image, numpy array, torch tensor, or a non-empty list of one of those.
        prompt: `str`, `list` or `None`; used to derive the prompt batch size.
        prompt_embeds: Used for the batch size (`shape[0]`) when `prompt` is not a `str` or `list`.

    Raises:
        TypeError: If `image` is none of the supported types (an empty list included).
        ValueError: If the image batch size is neither 1 nor equal to the prompt batch size, or if the prompt
            batch size cannot be determined while the image batch size is not 1.
    """
    image_is_pil = isinstance(image, PIL.Image.Image)
    image_is_tensor = isinstance(image, torch.Tensor)
    image_is_np = isinstance(image, np.ndarray)

    # Only peek at the first element of a non-empty list; an empty list matches none of the list
    # forms and is reported through the TypeError below instead of failing with an IndexError.
    image_is_nonempty_list = isinstance(image, list) and len(image) > 0
    image_is_pil_list = image_is_nonempty_list and isinstance(image[0], PIL.Image.Image)
    image_is_tensor_list = image_is_nonempty_list and isinstance(image[0], torch.Tensor)
    image_is_np_list = image_is_nonempty_list and isinstance(image[0], np.ndarray)

    if (
        not image_is_pil
        and not image_is_tensor
        and not image_is_np
        and not image_is_pil_list
        and not image_is_tensor_list
        and not image_is_np_list
    ):
        raise TypeError(
            f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
        )

    if image_is_pil:
        image_batch_size = 1
    else:
        image_batch_size = len(image)

    # Stays `None` when neither a str/list prompt nor embeddings are available.
    prompt_batch_size = None
    if prompt is not None and isinstance(prompt, str):
        prompt_batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]

    # A single image is always acceptable, so the prompt batch size only matters otherwise.
    if image_batch_size != 1:
        if prompt_batch_size is None:
            raise ValueError(
                "Cannot determine the prompt batch size: pass `prompt` as a `str` or `list`, or pass `prompt_embeds`."
            )
        if image_batch_size != prompt_batch_size:
            raise ValueError(
                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
            )
def prepare_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
):
    """
    Preprocess the control image and expand it along the batch dimension.

    The image is run through `self.control_image_processor.preprocess`, cast to float32, repeated along dim 0
    (`batch_size` times for a single image, `num_images_per_prompt` times otherwise), moved to
    `device` / `dtype`, and concatenated with itself when `do_classifier_free_guidance` is set.
    """
    processed = self.control_image_processor.preprocess(image, height=height, width=width)
    processed = processed.to(dtype=torch.float32)

    # A batch of more than one image is taken to already match the prompt batch size.
    has_single_image = processed.shape[0] == 1
    repeats = batch_size if has_single_image else num_images_per_prompt

    expanded = processed.repeat_interleave(repeats, dim=0).to(device=device, dtype=dtype)
    if not do_classifier_free_guidance:
        return expanded
    return torch.cat([expanded] * 2)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the initial latents, scaled by the scheduler's `init_noise_sigma`.

    When `latents` is `None`, noise of shape
    `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)` is drawn with
    `randn_tensor`; otherwise the given tensor is moved to `device`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(width) // self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        initial = latents.to(device)
    else:
        initial = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma


@property
def guidance_scale(self):
    """Value stored in `_guidance_scale`."""
    scale = self._guidance_scale
    return scale


@property
def clip_skip(self):
    """Value stored in `_clip_skip`."""
    skipped_layers = self._clip_skip
    return skipped_layers


@property
def do_classifier_free_guidance(self):
    """Whether `_guidance_scale > 1` and the unet config has no `time_cond_proj_dim`."""
    scale_enables_guidance = self._guidance_scale > 1
    return scale_enables_guidance and self.unet.config.time_cond_proj_dim is None


@property
def cross_attention_kwargs(self):
    """Value stored in `_cross_attention_kwargs`."""
    attention_kwargs = self._cross_attention_kwargs
    return attention_kwargs


@property
def num_timesteps(self):
    """Value stored in `_num_timesteps`."""
    timestep_count = self._num_timesteps
    return timestep_count
controlnet_conditioning_scale: Union[float, List[float]] = 1.0, control_guidance_start: float = 0.0, control_guidance_end: float = 1.0, clip_skip: Optional[int] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, images must be passed as a list such that each element of the list can be correctly batched for input to a single ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. 
Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set the corresponding scale as a list. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the ControlNet stops applying. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeine class. 
Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, image, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. 
Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare image image = self.prepare_image( image=image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=unet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, ) height, width = image.shape[-2:] # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 6. Prepare latent variables num_channels_latents = self.unet.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) is_controlnet_compiled = is_compiled_module(self.unet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if is_controlnet_compiled and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual apply_control = ( i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end ) noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds, controlnet_cond=image, conditioning_scale=controlnet_conditioning_scale, cross_attention_kwargs=cross_attention_kwargs, return_dict=True, apply_control=apply_control, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 
progress_bar.update() # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") torch.cuda.empty_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 0 ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, ) from diffusers.utils.import_utils import is_invisible_watermark_available from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel from ...models.attention_processor import ( AttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput if is_invisible_watermark_available(): from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install opencv-python transformers accelerate >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAdapter, AutoencoderKL >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> import cv2 >>> from PIL import Image >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" >>> negative_prompt = "low quality, bad quality, sketches" >>> # download an image >>> image = load_image( ... 
"https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" ... ) >>> # initialize the models and pipeline >>> controlnet_conditioning_scale = 0.5 >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) >>> controlnet = ControlNetXSAdapter.from_pretrained( ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 ... ) >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 ... ) >>> pipe.enable_model_cpu_offload() >>> # get canny image >>> image = np.array(image) >>> image = cv2.Canny(image, 100, 200) >>> image = image[:, :, None] >>> image = np.concatenate([image, image, image], axis=2) >>> canny_image = Image.fromarray(image) >>> # generate image >>> image = pipe( ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image ... ).images[0] ``` """ class StableDiffusionXLControlNetXSPipeline( DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet-XS guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 
text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): Second frozen text-encoder ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. tokenizer_2 ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. controlnet ([`ControlNetXSAdapter`]): A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): Whether the negative prompt embeddings should always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1-0`. add_watermarker (`bool`, *optional*): Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no watermarker is used. 
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
    self,
    prompt: str,
    prompt_2: Optional[str] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[str] = None,
    negative_prompt_2: Optional[str] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        device: (`torch.device`):
            torch device. Falls back to `self._execution_device` when not given.
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)`.
        The two negative entries are only computed or reshaped when `do_classifier_free_guidance` is truthy;
        otherwise they are returned exactly as they were passed in (possibly `None`).
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if self.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(self.text_encoder_2, lora_scale)

    # a single string becomes a one-element list; `None` stays `None`
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt is not None:
        batch_size = len(prompt)
    else:
        # no text prompt: the batch size is taken from the supplied embeddings
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders
    # when the first tokenizer / text encoder is missing, only the second one is used
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # textual inversion: process multi-vector tokens if necessary
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        # NOTE: this loop rebinds `prompt`; afterwards `prompt` holds the last entry that was processed,
        # not the caller's argument. `zip` stops at the shortest input, so with a single
        # tokenizer/encoder pair only the first entry of `prompts` is encoded.
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, tokenizer)

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            # tokenize again without truncation, only to detect and report text that was cut off
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            # here `prompt_embeds` temporarily holds the full encoder output object
            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

            # We are only ALWAYS interested in the pooled output of the final text encoder
            # (overwritten on every iteration, so the value from the last encoder in the list wins)
            pooled_prompt_embeds = prompt_embeds[0]
            if clip_skip is None:
                prompt_embeds = prompt_embeds.hidden_states[-2]
            else:
                # "2" because SDXL always indexes from the penultimate layer.
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

            prompt_embeds_list.append(prompt_embeds)

        # join the per-encoder hidden states along the last (feature) dimension
        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        # no negative prompt given and the config asks for zeros: use all-zero negative embeddings
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # normalize str to list
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt_2 = (
            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
        )

        uncond_tokens: List[str]
        # `prompt` here is whatever the encoding loop above left bound (see NOTE there)
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = [negative_prompt, negative_prompt_2]

        negative_prompt_embeds_list = []
        # same rebinding pattern as above: `negative_prompt` becomes the loop variable
        for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

            # pad/truncate the negative prompt to the sequence length of the positive embeddings
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            # NOTE: `clip_skip` is not applied on the negative branch; the penultimate layer is always used
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    # cast to the dtype of the second text encoder when present, otherwise to the UNet dtype
    if self.text_encoder_2 is not None:
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    else:
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        if self.text_encoder_2 is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        # NOTE(review): this view uses `batch_size` while the positive branch uses `bs_embed` -
        # presumably the two always agree; confirm against callers.
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
def check_inputs(
    self,
    prompt,
    prompt_2,
    image,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
    callback_on_step_end_tensor_inputs=None,
):
    r"""
    Validate the arguments of a generation call and raise on the first inconsistency found.

    The checks run in this order: callback tensor names, prompt / prompt-embedding exclusivity and types,
    negative prompt / negative-embedding exclusivity, embedding shape agreement, presence of the pooled
    embeddings, the control image and conditioning scale, and finally the control guidance window.

    Args:
        prompt, prompt_2 (`str` or `list`, *optional*):
            Text prompts. Mutually exclusive with `prompt_embeds`.
        image:
            ControlNet conditioning image, forwarded to `self.check_image`.
        negative_prompt, negative_prompt_2 (*optional*):
            Negative text prompts. Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds, negative_prompt_embeds (*optional*):
            Pre-computed embeddings; when both are given their `.shape` must be equal.
        pooled_prompt_embeds, negative_pooled_prompt_embeds (*optional*):
            Required whenever the corresponding non-pooled embeddings are passed.
        controlnet_conditioning_scale (`float`, defaults to 1.0):
            Must be an instance of `float`.
        control_guidance_start, control_guidance_end (`float`):
            Must satisfy `0.0 <= start < end <= 1.0`.
        callback_on_step_end_tensor_inputs (`list`, *optional*):
            Every entry must be listed in `self._callback_tensor_inputs`.

    Raises:
        ValueError: If any of the prompt, embedding, callback or guidance-window checks fails.
        TypeError: If `controlnet_conditioning_scale` is not a `float`.
        AssertionError: If `self.unet` is neither a `UNetControlNetXSModel` nor a compiled module wrapping one.

    Errors raised by `self.check_image` propagate unchanged.
    """
    if callback_on_step_end_tensor_inputs is not None:
        unknown_inputs = [
            k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs
        ]
        if unknown_inputs:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown_inputs}"
            )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_2 is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif prompt_2 is not None and not isinstance(prompt_2, (str, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )

    # Check `image` and ``controlnet_conditioning_scale``
    # `torch._dynamo` is only touched when `scaled_dot_product_attention` exists (short-circuit `and`)
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.unet, torch._dynamo.eval_frame.OptimizedModule
    )
    if isinstance(self.unet, UNetControlNetXSModel) or (
        is_compiled and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
    ):
        self.check_image(image, prompt, prompt_embeds)
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    else:
        # Raised explicitly instead of `assert False`: assert statements are removed under `python -O`,
        # which would let an unsupported `unet` pass validation silently. The exception type is kept.
        raise AssertionError(
            "`unet` has to be a `UNetControlNetXSModel` (optionally wrapped by `torch.compile`), but is"
            f" {type(self.unet)}."
        )

    start, end = control_guidance_start, control_guidance_end
    if start >= end:
        raise ValueError(
            f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
        )
    if start < 0.0:
        raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
    if end > 1.0:
        raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
def prepare_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
):
    r"""
    Preprocess the ControlNet conditioning image and expand it along the batch dimension.

    Args:
        image: Conditioning input handed to `self.control_image_processor.preprocess`.
        width, height: Forwarded to `preprocess` as the `width` / `height` keyword arguments.
        batch_size (`int`): Repeat count used when the preprocessed batch holds exactly one image.
        num_images_per_prompt (`int`): Repeat count used when the preprocessed batch holds several images.
        device, dtype: Device and dtype of the returned tensor.
        do_classifier_free_guidance (`bool`, defaults to `False`):
            When truthy, the expanded batch is concatenated with itself along dim 0.

    Returns:
        `torch.Tensor`: the conditioning tensor on `device` with dtype `dtype`.
    """
    control = self.control_image_processor.preprocess(image, height=height, width=width)
    control = control.to(dtype=torch.float32)

    # One image is shared by the whole batch; otherwise the image batch is taken to line up with the
    # prompt batch, so each image is only repeated once per generated sample.
    copies = batch_size if control.shape[0] == 1 else num_images_per_prompt
    control = control.repeat_interleave(copies, dim=0).to(device=device, dtype=dtype)

    return torch.cat([control, control]) if do_classifier_free_guidance else control
def _get_add_time_ids(
    self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
):
    """
    Build the additional time-id tensor from the size and crop conditioning values.

    `original_size`, `crops_coords_top_left` and `target_size` are concatenated with `+` into one flat list.
    Before the tensor is created, the embedding width implied by that list is compared with the input width
    of `self.unet.base_add_embedding.linear_1`.

    Args:
        original_size: Sequence concatenated first.
        crops_coords_top_left: Sequence concatenated second.
        target_size: Sequence concatenated third.
        dtype: dtype of the returned tensor.
        text_encoder_projection_dim: Added to `addition_time_embed_dim * number_of_ids` to obtain the width
            that is checked against the UNet.

    Returns:
        A `torch.Tensor` of shape `(1, number_of_ids)` holding the concatenated values.

    Raises:
        ValueError: If the computed embedding width differs from what the UNet expects.
    """
    id_values = list(original_size + crops_coords_top_left + target_size)

    per_id_embed_dim = self.unet.config.addition_time_embed_dim
    passed_add_embed_dim = per_id_embed_dim * len(id_values) + text_encoder_projection_dim
    expected_add_embed_dim = self.unet.base_add_embedding.linear_1.in_features

    # Happy path first: widths agree, so wrap the ids in a batch dimension and return.
    if passed_add_embed_dim == expected_add_embed_dim:
        return torch.tensor([id_values], dtype=dtype)

    raise ValueError(
        f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
    )
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    image: PipelineImageInput = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    control_guidance_start: float = 0.0,
    control_guidance_end: float = 1.0,
    original_size: Tuple[int, int] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    target_size: Tuple[int, int] = None,
    negative_original_size: Optional[Tuple[int, int]] = None,
    negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
    negative_target_size: Optional[Tuple[int, int]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`,
                `List[np.ndarray]`, `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or
                `List[List[PIL.Image.Image]]`):
            The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
            specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
            as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
            width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
            images must be passed as a list such that each element of the list can be correctly batched for input
            to a single ControlNet.
        height (`int`, *optional*):
            The height in pixels that the control `image` is preprocessed to. The generated image takes its
            height from the preprocessed control image. Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        width (`int`, *optional*):
            The width in pixels that the control `image` is preprocessed to. The generated image takes its width
            from the preprocessed control image. Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
            and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, pooled text embeddings are generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
            weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. With `"latent"`
            the denoised latents are returned without being decoded by the VAE.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original `unet`.
        control_guidance_start (`float`, *optional*, defaults to 0.0):
            The percentage of total steps at which the ControlNet starts applying.
        control_guidance_end (`float`, *optional*, defaults to 1.0):
            The percentage of total steps at which the ControlNet stops applying.
        original_size (`Tuple[int]`, *optional*):
            If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
            `original_size` defaults to the `(height, width)` of the preprocessed control image if not specified.
            Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
            `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
            `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        target_size (`Tuple[int]`, *optional*):
            For most cases, `target_size` should be set to the desired height and width of the generated image. If
            not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
            section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        negative_original_size (`Tuple[int]`, *optional*):
            To negatively condition the generation process based on a specific image resolution. Only used when
            `negative_target_size` is passed as well; otherwise the positive time ids are reused. Part of SDXL's
            micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
            micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_target_size (`Tuple[int]`, *optional*):
            To negatively condition the generation process based on a target image resolution. It should be as same
            as the `target_size` for most cases. Only used when `negative_original_size` is passed as well;
            otherwise the positive time ids are reused. Part of SDXL's micro-conditioning as explained in section
            2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during inference, with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is
            returned, otherwise a `tuple` is returned containing the output images.
    """

    # Callback objects carry their own list of tensor names; it overrides the argument.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # Look through a `torch.compile` wrapper so the isinstance check below sees the real module.
    unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        image,
        negative_prompt,
        negative_prompt_2,
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
        controlnet_conditioning_scale,
        control_guidance_start,
        control_guidance_end,
        callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt,
        prompt_2,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=clip_skip,
    )

    # 4. Prepare image
    if isinstance(unet, UNetControlNetXSModel):
        image = self.prepare_image(
            image=image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=unet.dtype,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )
        # The output resolution follows the preprocessed control image.
        height, width = image.shape[-2:]
    else:
        # Explicit raise rather than `assert False`: assert statements are stripped under `python -O`,
        # which would let execution continue with an unprepared image.
        raise AssertionError(f"`unet` must be a `UNetControlNetXSModel`, but is {type(unet)}.")

    # 5. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 6. Prepare latent variables
    num_channels_latents = self.unet.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7.1 Prepare added time ids & embeddings
    if isinstance(image, list):
        original_size = original_size or image[0].shape[-2:]
    else:
        original_size = original_size or image.shape[-2:]
    target_size = target_size or (height, width)

    add_text_embeds = pooled_prompt_embeds
    if self.text_encoder_2 is None:
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
    else:
        text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

    add_time_ids = self._get_add_time_ids(
        original_size,
        crops_coords_top_left,
        target_size,
        dtype=prompt_embeds.dtype,
        text_encoder_projection_dim=text_encoder_projection_dim,
    )

    # Negative micro-conditioning is only built when both negative sizes are given;
    # otherwise the unconditional branch reuses the positive time ids.
    if negative_original_size is not None and negative_target_size is not None:
        negative_add_time_ids = self._get_add_time_ids(
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
    else:
        negative_add_time_ids = add_time_ids

    # Unconditional inputs come first so that `noise_pred.chunk(2)` below yields (uncond, text).
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    is_controlnet_compiled = is_compiled_module(self.unet)
    is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Relevant thread:
            # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
            if is_controlnet_compiled and is_torch_higher_equal_2_1:
                torch._inductor.cudagraph_mark_step_begin()
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

            # predict the noise residual; control is only applied inside the
            # [control_guidance_start, control_guidance_end] fraction of the schedule
            apply_control = (
                i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end
            )
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=prompt_embeds,
                controlnet_cond=image,
                conditioning_scale=controlnet_conditioning_scale,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=True,
                apply_control=apply_control,
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if callback_on_step_end is not None:
                # Tensors are looked up by local-variable name, so the local names used in this
                # function are part of the callback contract. Keep this a plain loop: inside a
                # comprehension `locals()` may not refer to this function's scope.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # update the progress bar
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    if output_type == "latent":
        # Latent output skips the VAE entirely, so the VAE is left untouched.
        image = latents
    else:
        # make sure the VAE is in float32 mode, as it overflows in float16
        # The upcast happens only here, right before decoding, so that `needs_upcasting` still reflects
        # the VAE's fp16 state and the cast back below actually runs.
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        if needs_upcasting:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # apply watermark if available
        if self.watermark is not None:
            image = self.watermark.apply_watermark(image)

        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return StableDiffusionXLPipelineOutput(images=image)
class DanceDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for audio generation.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet1DModel`]):
            A `UNet1DModel` to denoise the encoded audio.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
            [`IPNDMScheduler`].
    """

    model_cpu_offload_seq = "unet"

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 100,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        audio_length_in_s: Optional[float] = None,
        return_dict: bool = True,
    ) -> Union[AudioPipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of audio samples to generate.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher-quality audio sample at
                the expense of slower inference.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
                The length of the generated audio sample in seconds.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        from scipy.io.wavfile import write

        model_id = "harmonai/maestro-150k"
        pipe = DiffusionPipeline.from_pretrained(model_id)
        pipe = pipe.to("cuda")

        audios = pipe(audio_length_in_s=4.0).audios

        # To save locally
        for i, audio in enumerate(audios):
            write(f"maestro_test_{i}.wav", pipe.unet.config.sample_rate, audio.transpose())

        # To display in google colab
        import IPython.display as ipd

        for audio in audios:
            ipd.display(ipd.Audio(audio, rate=pipe.unet.config.sample_rate))
        ```

        Returns:
            [`~pipelines.AudioPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the generated audio as a NumPy array.
        """
        if audio_length_in_s is None:
            audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate

        # Requested length in samples; may be fractional at this point.
        sample_size = audio_length_in_s * self.unet.config.sample_rate

        down_scale_factor = 2 ** len(self.unet.up_blocks)
        if sample_size < 3 * down_scale_factor:
            raise ValueError(
                f"{audio_length_in_s} is too small. Make sure it's bigger or equal to"
                f" {3 * down_scale_factor / self.unet.config.sample_rate}."
            )

        # Remember the requested length so the padded result can be cropped back afterwards.
        original_sample_size = int(sample_size)
        needs_padding = sample_size % down_scale_factor != 0
        if needs_padding:
            # Round up to the next multiple of the down-scale factor.
            sample_size = (sample_size // down_scale_factor + 1) * down_scale_factor
            logger.info(
                f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled"
                f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising"
                " process."
            )
        sample_size = int(sample_size)

        unet_dtype = next(self.unet.parameters()).dtype
        noise_shape = (batch_size, self.unet.config.in_channels, sample_size)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        audio = randn_tensor(noise_shape, generator=generator, device=self._execution_device, dtype=unet_dtype)

        # set step values; the timesteps are cast to the dtype of the UNet parameters
        self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
        self.scheduler.timesteps = self.scheduler.timesteps.to(unet_dtype)

        for timestep in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            noise_estimate = self.unet(audio, timestep).sample

            # 2. compute previous audio sample: x_t -> x_t-1
            audio = self.scheduler.step(noise_estimate, timestep, audio).prev_sample

        # Clamp to [-1, 1], move to a float32 NumPy array and drop the padding along the last axis.
        audio = audio.clamp(-1, 1).float().cpu().numpy()[:, :, :original_sample_size]

        return AudioPipelineOutput(audios=audio) if return_dict else (audio,)
@torch.no_grad()
def __call__(
    self,
    batch_size: int = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    eta: float = 0.0,
    num_inference_steps: int = 50,
    use_clipped_model_output: Optional[bool] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
    r"""
    The call function to the pipeline for generation.

    Args:
        batch_size (`int`, *optional*, defaults to 1):
            The number of images to generate.
        generator (`torch.Generator`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. A value of `0` corresponds to
            DDIM and `1` corresponds to DDPM.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        use_clipped_model_output (`bool`, *optional*, defaults to `None`):
            Forwarded to the scheduler's `step` call; see documentation for [`DDIMScheduler.step`].
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"pil"` converts the result with `numpy_to_pil`; any other
            value returns the NumPy array.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

    Example:

    ```py
    >>> from diffusers import DDIMPipeline

    >>> # load model and scheduler
    >>> pipe = DDIMPipeline.from_pretrained("fusing/ddim-lsun-bedroom")

    >>> # run pipeline in inference (sample random noise and denoise)
    >>> image = pipe(eta=0.0, num_inference_steps=50).images[0]

    >>> # save image
    >>> image.save("test.png")
    ```

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
            returned where the first element is a list with the generated images
    """
    # Sample gaussian noise to begin loop. An integer `sample_size` means a square image; anything else is
    # unpacked as the spatial dimensions.
    sample_size = self.unet.config.sample_size
    spatial_dims = (sample_size, sample_size) if isinstance(sample_size, int) else sample_size
    image_shape = (batch_size, self.unet.config.in_channels, *spatial_dims)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype)

    # set step values
    self.scheduler.set_timesteps(num_inference_steps)

    for timestep in self.progress_bar(self.scheduler.timesteps):
        # 1. predict noise model_output
        noise_estimate = self.unet(image, timestep).sample

        # 2. predict previous mean of image x_t-1 and add variance depending on eta
        # eta corresponds to η in paper and should be between [0, 1]
        # do x_t -> x_t-1
        step_output = self.scheduler.step(
            noise_estimate,
            timestep,
            image,
            eta=eta,
            use_clipped_model_output=use_clipped_model_output,
            generator=generator,
        )
        image = step_output.prev_sample

    # Map from [-1, 1] to [0, 1], then move channels last for the NumPy / PIL conversion.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    if output_type == "pil":
        image = self.numpy_to_pil(image)

    return ImagePipelineOutput(images=image) if return_dict else (image,)
class DDPMPipeline(DiffusionPipeline):
    r"""
    Pipeline for image generation.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    model_cpu_offload_seq = "unet"

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            num_inference_steps (`int`, *optional*, defaults to 1000):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Example:

        ```py
        >>> from diffusers import DDPMPipeline

        >>> # load model and scheduler
        >>> pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")

        >>> # run pipeline in inference (sample random noise and denoise)
        >>> image = pipe().images[0]

        >>> # save image
        >>> image.save("ddpm_generated_image.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images
        """
        # Sample gaussian noise to begin loop
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        # Draw the noise directly in the UNet's dtype so that a half-precision model is not fed
        # float32 input on the first denoising step.
        noise_dtype = self.unet.dtype

        if self.device.type == "mps":
            # randn does not work reproducibly on mps, so sample without a device and move afterwards
            image = randn_tensor(image_shape, generator=generator, dtype=noise_dtype)
            image = image.to(self.device)
        else:
            image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=noise_dtype)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(image, t).sample

            # 2. compute previous image: x_t -> x_t-1
            image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

        # map from [-1, 1] to [0, 1] and convert to a channels-last numpy array
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_if import IFPipeline from .pipeline_if_img2img import IFImg2ImgPipeline from .pipeline_if_img2img_superresolution import IFImg2ImgSuperResolutionPipeline from .pipeline_if_inpainting import IFInpaintingPipeline from .pipeline_if_inpainting_superresolution import IFInpaintingSuperResolutionPipeline from .pipeline_if_superresolution import IFSuperResolutionPipeline from .pipeline_output import IFPipelineOutput from .safety_checker import IFSafetyChecker from .timesteps import ( fast27_timesteps, smart27_timesteps, smart50_timesteps, smart100_timesteps, smart185_timesteps, super27_timesteps, super40_timesteps, super100_timesteps, ) from .watermark import IFWatermarker else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/deepfloyd_if/pipeline_if.py ================================================ import html import inspect import re import urllib.parse as ul from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer from ...loaders import StableDiffusionLoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, is_bs4_available, is_ftfy_available, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import IFPipelineOutput from .safety_checker import IFSafetyChecker from .watermark import IFWatermarker logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_bs4_available(): from bs4 import 
BeautifulSoup if is_ftfy_available(): import ftfy EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline >>> from diffusers.utils import pt_to_pil >>> import torch >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16) >>> pipe.enable_model_cpu_offload() >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images >>> # save intermediate image >>> pil_image = pt_to_pil(image) >>> pil_image[0].save("./if_stage_I.png") >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 ... ) >>> super_res_1_pipe.enable_model_cpu_offload() >>> image = super_res_1_pipe( ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt" ... ).images >>> # save intermediate image >>> pil_image = pt_to_pil(image) >>> pil_image[0].save("./if_stage_I.png") >>> safety_modules = { ... "feature_extractor": pipe.feature_extractor, ... "safety_checker": pipe.safety_checker, ... "watermarker": pipe.watermarker, ... } >>> super_res_2_pipe = DiffusionPipeline.from_pretrained( ... "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16 ... ) >>> super_res_2_pipe.enable_model_cpu_offload() >>> image = super_res_2_pipe( ... prompt=prompt, ... image=image, ... 
).images >>> image[0].save("./if_stage_II.png") ``` """ class IFPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): tokenizer: T5Tokenizer text_encoder: T5EncoderModel unet: UNet2DConditionModel scheduler: DDPMScheduler feature_extractor: Optional[CLIPImageProcessor] safety_checker: Optional[IFSafetyChecker] watermarker: Optional[IFWatermarker] bad_punct_regex = re.compile( r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}" ) # noqa _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] model_cpu_offload_seq = "text_encoder->unet" _exclude_from_cpu_offload = ["watermarker"] def __init__( self, tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, unet: UNet2DConditionModel, scheduler: DDPMScheduler, safety_checker: Optional[IFSafetyChecker], feature_extractor: Optional[CLIPImageProcessor], watermarker: Optional[IFWatermarker], requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the IF license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, watermarker=watermarker, ) self.register_to_config(requires_safety_checker=requires_safety_checker) @torch.no_grad() def encode_prompt( self, prompt: Union[str, List[str]], do_classifier_free_guidance: bool = True, num_images_per_prompt: int = 1, device: Optional[torch.device] = None, negative_prompt: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, clean_caption: bool = False, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt device: (`torch.device`, *optional*): torch device to place the resulting embeddings on negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
clean_caption (bool, defaults to `False`): If `True`, the function will preprocess and clean the provided caption before encoding. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = 
prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes else: negative_prompt_embeds = None return prompt_embeds, negative_prompt_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is not None: safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) image, nsfw_detected, watermark_detected = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype=dtype), ) else: nsfw_detected = None watermark_detected = None return image, nsfw_detected, watermark_detected # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator): shape = (batch_size, num_channels, height, width) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # scale the initial noise by the standard deviation required by the scheduler intermediate_images = intermediate_images * self.scheduler.init_noise_sigma return intermediate_images def _text_preprocessing(self, text, clean_caption=False): if clean_caption and not is_bs4_available(): logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) logger.warning("Setting `clean_caption` to False...") clean_caption = False if clean_caption and not is_ftfy_available(): logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) logger.warning("Setting `clean_caption` to False...") clean_caption = False if not isinstance(text, (tuple, list)): text = [text] def process(text: str): if clean_caption: text = self._clean_caption(text) text = self._clean_caption(text) else: text = text.lower().strip() return text return [process(t) for t in text] def _clean_caption(self, caption): caption = str(caption) caption = ul.unquote_plus(caption) caption = caption.strip().lower() caption = re.sub("", "person", caption) # urls: caption = re.sub( r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa "", caption, ) # regex for urls caption = re.sub( r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa "", caption, ) # regex for urls # html: caption = BeautifulSoup(caption, features="html.parser").text # @ caption = re.sub(r"@[\w\d]+\b", "", caption) # 31C0—31EF CJK Strokes # 31F0—31FF Katakana Phonetic Extensions # 3200—32FF Enclosed CJK Letters and Months # 3300—33FF CJK Compatibility # 3400—4DBF CJK Unified Ideographs Extension A # 4DC0—4DFF Yijing Hexagram Symbols # 4E00—9FFF CJK Unified Ideographs caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) caption = 
re.sub(r"[\u3200-\u32ff]+", "", caption) caption = re.sub(r"[\u3300-\u33ff]+", "", caption) caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) ####################################################### # все виды тире / all types of dash --> "-" caption = re.sub( r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa "-", caption, ) # кавычки к одному стандарту caption = re.sub(r"[`´«»“”¨]", '"', caption) caption = re.sub(r"[‘’]", "'", caption) # " caption = re.sub(r""?", "", caption) # & caption = re.sub(r"&", "", caption) # ip adresses: caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) # article ids: caption = re.sub(r"\d:\d\d\s+$", "", caption) # \n caption = re.sub(r"\\n", " ", caption) # "#123" caption = re.sub(r"#\d{1,3}\b", "", caption) # "#12345.." caption = re.sub(r"#\d{5,}\b", "", caption) # "123456.." caption = re.sub(r"\b\d{6,}\b", "", caption) # filenames: caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) # caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT caption = re.sub(r"\s+\.\s+", r" ", caption) # " . 
" # this-is-my-cute-cat / this_is_my_cute_cat regex2 = re.compile(r"(?:\-|\_)") if len(re.findall(regex2, caption)) > 3: caption = re.sub(regex2, " ", caption) caption = ftfy.fix_text(caption) caption = html.unescape(html.unescape(caption)) caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) caption = re.sub(r"\bpage\s+\d+\b", "", caption) caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) caption = re.sub(r"\b\s+\:\s+", r": ", caption) caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) caption = re.sub(r"\s+", " ", caption) caption.strip() caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) caption = re.sub(r"^\.\S+$", "", caption) return caption.strip() @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, num_inference_steps: int = 100, timesteps: List[int] = None, guidance_scale: float = 7.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, height: Optional[int] = None, width: Optional[int] = None, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] 
= None, callback_steps: int = 1, clean_caption: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. height (`int`, *optional*, defaults to self.unet.config.sample_size): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size): The width in pixels of the generated image. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. clean_caption (`bool`, *optional*, defaults to `True`): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). Examples: Returns: [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) or watermarked content, according to the `safety_checker`. """ # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. Define call parameters height = height or self.unet.config.sample_size width = width or self.unet.config.sample_size if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, device=device, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clean_caption=clean_caption, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. 
Prepare timesteps if timesteps is not None: self.scheduler.set_timesteps(timesteps=timesteps, device=device) timesteps = self.scheduler.timesteps num_inference_steps = len(timesteps) else: self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps if hasattr(self.scheduler, "set_begin_index"): self.scheduler.set_begin_index(0) # 5. Prepare intermediate images intermediate_images = self.prepare_intermediate_images( batch_size * num_images_per_prompt, self.unet.config.in_channels, height, width, prompt_embeds.dtype, device, generator, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # HACK: see comment in `enable_model_cpu_offload` if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: self.text_encoder_offload_hook.offload() # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): model_input = ( torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images ) model_input = self.scheduler.scale_model_input(model_input, t) # predict the noise residual noise_pred = self.unet( model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) if self.scheduler.config.variance_type not in ["learned", "learned_range"]: noise_pred, _ = 
noise_pred.split(model_input.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) image = intermediate_images if output_type == "pil": # 8. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 9. Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # 10. Convert to PIL image = self.numpy_to_pil(image) # 11. Apply watermark if self.watermarker is not None: image = self.watermarker.apply_watermark(image, self.unet.config.sample_size) elif output_type == "pt": nsfw_detected = None watermark_detected = None if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: self.unet_offload_hook.offload() else: # 8. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 9. 
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    """Resize an image so its shorter side is `img_size`, keeping the aspect ratio.

    The longer side is scaled by the aspect ratio and snapped to the nearest multiple of 8.

    Args:
        images: Image to resize; must expose `.size` as `(width, height)` and a `.resize` method.
        img_size: Target length, in pixels, of the shorter side.

    Returns:
        The resized image produced with bicubic resampling.
    """
    src_width, src_height = images.size
    aspect = src_width / src_height

    if aspect >= 1:
        # landscape or square: height is pinned, width follows the aspect ratio
        target_width = int(round(img_size / 8 * aspect) * 8)
        target_height = img_size
    else:
        # portrait: width is pinned, height follows the aspect ratio
        target_width = img_size
        target_height = int(round(img_size / 8 / aspect) * 8)

    return images.resize(
        (target_width, target_height), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None
    )
import BytesIO >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" >>> response = requests.get(url) >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") >>> original_image = original_image.resize((768, 512)) >>> pipe = IFImg2ImgPipeline.from_pretrained( ... "DeepFloyd/IF-I-XL-v1.0", ... variant="fp16", ... torch_dtype=torch.float16, ... ) >>> pipe.enable_model_cpu_offload() >>> prompt = "A fantasy landscape in style minecraft" >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) >>> image = pipe( ... image=original_image, ... prompt_embeds=prompt_embeds, ... negative_prompt_embeds=negative_embeds, ... output_type="pt", ... ).images >>> # save intermediate image >>> pil_image = pt_to_pil(image) >>> pil_image[0].save("./if_stage_I.png") >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained( ... "DeepFloyd/IF-II-L-v1.0", ... text_encoder=None, ... variant="fp16", ... torch_dtype=torch.float16, ... ) >>> super_res_1_pipe.enable_model_cpu_offload() >>> image = super_res_1_pipe( ... image=image, ... original_image=original_image, ... prompt_embeds=prompt_embeds, ... negative_prompt_embeds=negative_embeds, ... 
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and the `requires_safety_checker` config flag.

    Args:
        tokenizer: Tokenizer registered under `tokenizer`.
        text_encoder: Text encoder registered under `text_encoder`.
        unet: Denoising model registered under `unet`.
        scheduler: Scheduler registered under `scheduler`.
        safety_checker: Optional safety checker. Passing `None` while `requires_safety_checker` is
            `True` only logs a warning.
        feature_extractor: Optional feature extractor; required whenever `safety_checker` is given.
        watermarker: Optional watermarker registered under `watermarker`.
        requires_safety_checker: Stored in the pipeline config; also controls whether the
            "safety checker disabled" warning is emitted.

    Raises:
        ValueError: If `safety_checker` is provided without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message must be an f-string, otherwise `{self.__class__}` is printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
clean_caption (bool, defaults to `False`): If `True`, the function will preprocess and clean the provided caption before encoding. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = 
prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments accepted by `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator` are forwarded only
    when the scheduler's `step` declares a parameter of that name.

    Args:
        generator: Value forwarded as `generator` when supported.
        eta: Value forwarded as `eta` when supported.

    Returns:
        A dict holding the subset of `{"eta": eta, "generator": generator}` that `step` accepts.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if isinstance(image, list): check_image_type = image[0] else: check_image_type = image if ( not isinstance(check_image_type, torch.Tensor) and not isinstance(check_image_type, PIL.Image.Image) and not isinstance(check_image_type, np.ndarray) ): raise ValueError( "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] 
def _clean_caption(self, caption):
    """Normalize a raw caption before it is tokenized.

    The caption is URL-unquoted, lower-cased, stripped of URLs, HTML markup, @-mentions, CJK
    ranges, ids, file names and boilerplate phrases, and has its dashes, quotes and whitespace
    normalized.

    Args:
        caption: Any object; it is converted with `str()` first.

    Returns:
        The cleaned caption with leading and trailing whitespace removed.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # Replace the literal placeholder tag. An empty pattern here would match at every position
    # and insert "person" between all characters.
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @-mentions
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to a single standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # leftover "&quot;" entities (with or without the trailing semicolon)
    caption = re.sub(r"&quot;?", "", caption)
    # leftover "&amp" entities
    caption = re.sub(r"&amp", "", caption)

    # ip adresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # literal backslash-n sequences
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # collapse repeated quotes and repeated dots
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    # only split on "-"/"_" when there are more than three of them
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    # unescape twice to also resolve double-escaped entities
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    # dimensions such as "1024x768" (latin x, cyrillic х, or ×)
    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE(review): the result of this call is discarded, so it is a no-op. It is left unchanged
    # because assigning it would alter the cleaned text fed to the text encoder — confirm intent.
    caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
def get_timesteps(self, num_inference_steps, strength):
    """Select the tail of the scheduler's timesteps that corresponds to `strength`.

    Args:
        num_inference_steps: Number of inference steps the scheduler was configured with.
        strength: Fraction of the schedule to run; `int(num_inference_steps * strength)` is
            capped at `num_inference_steps`.

    Returns:
        A tuple `(timesteps, steps)`: the slice of `self.scheduler.timesteps` to iterate over and
        the number of inference steps that remain.
    """
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    first_step = max(num_inference_steps - steps_to_run, 0)

    # each inference step spans `order` entries of the scheduler's timestep list
    offset = first_step * self.scheduler.order
    remaining = self.scheduler.timesteps[offset:]

    if hasattr(self.scheduler, "set_begin_index"):
        self.scheduler.set_begin_index(offset)

    return remaining, num_inference_steps - first_step
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[
        PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
    ] = None,
    strength: float = 0.7,
    num_inference_steps: int = 80,
    timesteps: List[int] = None,
    guidance_scale: float = 10.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    clean_caption: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        image (`torch.Tensor` or `PIL.Image.Image`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process.
        strength (`float`, *optional*, defaults to 0.7):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 80):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
            timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 10.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only forwarded
            to the scheduler when its `step` method accepts an `eta` argument.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"pil"` returns a list of `PIL.Image.Image` (safety checker
            and watermark applied), `"pt"` returns the raw tensor without running the safety checker, and any
            other value returns an `np.array` after running the safety checker.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
            be installed. If the dependencies are not installed, the embeddings will be created from the raw
            prompt.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is a list with the generated images, and the second and third
        elements are the `nsfw_detected` and `watermark_detected` results of the `safety_checker` (`None` when
        the safety checker is not run).
    """
    # 1. Check inputs. Raise error if not correct
    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    elif prompt_embeds is not None:
        batch_size = prompt_embeds.shape[0]
    else:
        # Neither a usable `prompt` nor `prompt_embeds` was given. Leave the batch size unset so
        # that `check_inputs` raises its descriptive ValueError instead of this code failing with
        # an AttributeError on `None.shape`.
        batch_size = None

    self.check_inputs(
        prompt, image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # 2. Define call parameters
    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
    )

    if do_classifier_free_guidance:
        # unconditional embeddings first, so that `chunk(2)` below yields (uncond, text)
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    dtype = prompt_embeds.dtype

    # 4. Prepare timesteps
    if timesteps is not None:
        self.scheduler.set_timesteps(timesteps=timesteps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)

    # 5. Prepare intermediate images
    image = self.preprocess_image(image)
    image = image.to(device=device, dtype=dtype)

    # the input image is noised to the first timestep that will actually be denoised
    noise_timestep = timesteps[0:1]
    noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)

    intermediate_images = self.prepare_intermediate_images(
        image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, generator
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # HACK: see comment in `enable_model_cpu_offload`
    if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
        self.text_encoder_offload_hook.offload()

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            model_input = (
                torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
            )
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # the channel dimension is split into a noise part and a remainder that is kept as
                # the predicted variance; only the text branch's remainder is reused
                noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
                noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            intermediate_images = self.scheduler.step(
                noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
            )[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, intermediate_images)

    image = intermediate_images

    if output_type == "pil":
        # 8. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 9. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 10. Convert to PIL
        image = self.numpy_to_pil(image)

        # 11. Apply watermark
        if self.watermarker is not None:
            # keep the returned images, matching how IFPipeline.__call__ uses `apply_watermark`
            image = self.watermarker.apply_watermark(image, self.unet.config.sample_size)
    elif output_type == "pt":
        nsfw_detected = None
        watermark_detected = None

        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
            self.unet_offload_hook.offload()
    else:
        # 8. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 9. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, nsfw_detected, watermark_detected)

    return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    """Resize an image so its shorter side equals `img_size`, keeping the aspect ratio.

    The longer side is scaled by the width/height ratio and rounded to a multiple of 8.

    Args:
        images: The image to resize (a single PIL image, despite the plural name).
        img_size: Target length, in pixels, of the shorter side.

    Returns:
        The bicubically resampled image.
    """
    src_width, src_height = images.size
    aspect = src_width / src_height

    if aspect >= 1:
        # Landscape or square: height is pinned, width follows the aspect ratio.
        target_size = (int(round(img_size / 8 * aspect) * 8), img_size)
    else:
        # Portrait: width is pinned, height follows the inverse aspect ratio.
        target_size = (img_size, int(round(img_size / 8 / aspect) * 8))

    return images.resize(target_size, resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
original_image=original_image, ... prompt_embeds=prompt_embeds, ... negative_prompt_embeds=negative_embeds, ... ).images >>> image[0].save("./if_stage_II.png") ``` """ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): tokenizer: T5Tokenizer text_encoder: T5EncoderModel unet: UNet2DConditionModel scheduler: DDPMScheduler image_noising_scheduler: DDPMScheduler feature_extractor: Optional[CLIPImageProcessor] safety_checker: Optional[IFSafetyChecker] watermarker: Optional[IFWatermarker] bad_punct_regex = re.compile( r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}" ) # noqa _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor"] model_cpu_offload_seq = "text_encoder->unet" _exclude_from_cpu_offload = ["watermarker"] def __init__( self, tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, unet: UNet2DConditionModel, scheduler: DDPMScheduler, image_noising_scheduler: DDPMScheduler, safety_checker: Optional[IFSafetyChecker], feature_extractor: Optional[CLIPImageProcessor], watermarker: Optional[IFWatermarker], requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the IF license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
    """Normalize a raw caption before it is tokenized.

    The caption is URL-decoded, lower-cased and then passed through a fixed
    sequence of substitutions that strip URLs, HTML markup, @-handles, several
    CJK code-point ranges, IP addresses, file names, id-like tokens and common
    boilerplate phrases, and that normalize dashes, quotes and whitespace.

    Requires `BeautifulSoup` (bs4) and `ftfy` to be importable; the caller is
    expected to have checked that before enabling caption cleaning.

    Args:
        caption: The caption to clean; converted with `str()` first.

    Returns:
        The cleaned caption with surrounding whitespace removed.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # The literal "<person>" placeholder tag becomes the plain word. An empty
    # pattern here would insert "person" between every character.
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @-handles
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to a single standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # leftover HTML quote entity (with or without the trailing semicolon)
    caption = re.sub(r"&quot;?", "", caption)
    # leftover HTML ampersand entity
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # literal backslash-n sequences
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # runs of quotes / dots
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    # Only split on separators when there are more than three of them, so
    # ordinary hyphenated words are left alone.
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    # Unescape twice to also resolve double-escaped entities.
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    # dimensions such as "1920x1080" (latin x, cyrillic х or ×)
    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE(review): the result of this call is discarded, so the anchored
    # patterns below still see any leading/trailing space. Left unchanged on
    # purpose so the cleaned output stays identical to what it was before.
    caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. clean_caption (bool, defaults to `False`): If `True`, the function will preprocess and clean the provided caption before encoding. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, 
device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def check_inputs(
    self,
    prompt,
    image,
    original_image,
    batch_size,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the arguments of a pipeline call before any work is done.

    Args:
        prompt: `str`, `list` or `None`; mutually exclusive with `prompt_embeds`.
        image: Tensor, PIL image, ndarray, or a non-empty list of those.
        original_image: Same accepted types as `image`.
        batch_size: Prompt batch size that both image inputs must match.
        callback_steps: Must be a positive `int`.
        negative_prompt: Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: Optional precomputed embeddings.
        negative_prompt_embeds: Optional precomputed negative embeddings; must
            have the same shape as `prompt_embeds` when both are given.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    # `None` is not an int, so this single test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    def _validate_image_input(value, label):
        # Shared type and batch-size validation for `image` and `original_image`.
        if isinstance(value, list):
            if not value:
                # Previously surfaced as a bare IndexError from `value[0]`.
                raise ValueError(f"`{label}` must not be an empty list.")
            sample = value[0]
        else:
            sample = value

        if not (
            isinstance(sample, torch.Tensor)
            or isinstance(sample, PIL.Image.Image)
            or isinstance(sample, np.ndarray)
        ):
            raise ValueError(
                f"`{label}` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
                f" {type(sample)}"
            )

        if isinstance(value, list):
            found = len(value)
        elif isinstance(value, torch.Tensor) or isinstance(value, np.ndarray):
            found = value.shape[0]
        else:
            # The type check above leaves a single PIL image as the only option.
            found = 1

        if batch_size != found:
            raise ValueError(f"{label} batch size: {found} must be same as prompt batch size {batch_size}")

    _validate_image_input(image, "image")
    _validate_image_input(original_image, "original_image")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
    """Return the tail of the scheduler's timesteps selected by `strength`.

    Args:
        num_inference_steps: Number of inference steps the scheduler was set up with.
        strength: Fraction of those steps to actually run; the product is
            truncated to an int and capped at `num_inference_steps`.

    Returns:
        A `(timesteps, steps)` tuple: the slice of `self.scheduler.timesteps`
        left after skipping the leading steps, and the number of steps kept.
    """
    scheduler = self.scheduler

    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    skipped_steps = max(num_inference_steps - steps_to_run, 0)

    # Each inference step spans `scheduler.order` entries of the timestep list.
    begin_index = skipped_steps * scheduler.order
    remaining_timesteps = scheduler.timesteps[begin_index:]

    # Only schedulers that expose `set_begin_index` are told where the run starts.
    if hasattr(scheduler, "set_begin_index"):
        scheduler.set_begin_index(begin_index)

    return remaining_timesteps, num_inference_steps - skipped_steps
clean_caption: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: image (`torch.Tensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. original_image (`torch.Tensor` or `PIL.Image.Image`): The original image that `image` was varied from. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. 
If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). noise_level (`int`, *optional*, defaults to 250): The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` clean_caption (`bool`, *optional*, defaults to `True`): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. Examples: Returns: [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) or watermarked content, according to the `safety_checker`. """ # 1. Check inputs. Raise error if not correct if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] self.check_inputs( prompt, image, original_image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) # 2. Define call parameters # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 device = self._execution_device # 3. 
Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, device=device, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clean_caption=clean_caption, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) dtype = prompt_embeds.dtype # 4. Prepare timesteps if timesteps is not None: self.scheduler.set_timesteps(timesteps=timesteps, device=device) timesteps = self.scheduler.timesteps num_inference_steps = len(timesteps) else: self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) # 5. prepare original image original_image = self.preprocess_original_image(original_image) original_image = original_image.to(device=device, dtype=dtype) # 6. Prepare intermediate images noise_timestep = timesteps[0:1] noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) intermediate_images = self.prepare_intermediate_images( original_image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, generator, ) # 7. Prepare upscaled image and noise level _, _, height, width = original_image.shape image = self.preprocess_image(image, num_images_per_prompt, device) upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) if do_classifier_free_guidance: noise_level = torch.cat([noise_level] * 2) # 8. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # HACK: see comment in `enable_model_cpu_offload` if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: self.text_encoder_offload_hook.offload() # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): model_input = torch.cat([intermediate_images, upscaled], dim=1) model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input model_input = self.scheduler.scale_model_input(model_input, t) # predict the noise residual noise_pred = self.unet( model_input, t, encoder_hidden_states=prompt_embeds, class_labels=noise_level, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) if self.scheduler.config.variance_type not in ["learned", "learned_range"]: noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) image = intermediate_images if output_type == "pil": # 10. 
Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 11. Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # 12. Convert to PIL image = self.numpy_to_pil(image) # 13. Apply watermark if self.watermarker is not None: self.watermarker.apply_watermark(image, self.unet.config.sample_size) elif output_type == "pt": nsfw_detected = None watermark_detected = None else: # 10. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 11. Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image, nsfw_detected, watermark_detected) return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) ================================================ FILE: src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py ================================================ import html import inspect import re import urllib.parse as ul from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL.Image import torch from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer from ...loaders import StableDiffusionLoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, PIL_INTERPOLATION, is_bs4_available, is_ftfy_available, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import IFPipelineOutput from .safety_checker import IFSafetyChecker from .watermark import IFWatermarker logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_bs4_available(): from bs4 import BeautifulSoup if 
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize (restyled, same behavior)
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    """
    Resize a PIL image so that its shorter side equals `img_size` while roughly keeping the aspect ratio.

    Args:
        images (`PIL.Image.Image`):
            The image to resize. Only `.size` and `.resize(...)` are used.
        img_size (`int`):
            Target length of the shorter side. The longer side is `img_size` scaled by the aspect ratio and
            rounded to a multiple of 8.

    Returns:
        `PIL.Image.Image`: The image resized with bicubic resampling.
    """
    src_width, src_height = images.size
    aspect = src_width / src_height

    if aspect >= 1:
        # landscape or square: height is pinned, width follows the aspect ratio
        target_size = (int(round(img_size / 8 * aspect) * 8), img_size)
    else:
        # portrait: width is pinned, height follows the aspect ratio
        target_size = (img_size, int(round(img_size / 8 / aspect) * 8))

    return images.resize(target_size, resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and record whether a safety checker is required.

    Args:
        tokenizer (`T5Tokenizer`): Tokenizer registered under `tokenizer`.
        text_encoder (`T5EncoderModel`): Text encoder registered under `text_encoder`.
        unet (`UNet2DConditionModel`): Denoising model registered under `unet`.
        scheduler (`DDPMScheduler`): Scheduler registered under `scheduler`.
        safety_checker (`IFSafetyChecker`, *optional*):
            Safety checker module. Passing `None` while `requires_safety_checker` is `True` only logs a warning.
        feature_extractor (`CLIPImageProcessor`, *optional*):
            Must be provided whenever `safety_checker` is not `None`.
        watermarker (`IFWatermarker`, *optional*): Watermarker registered under `watermarker`.
        requires_safety_checker (`bool`, defaults to `True`):
            Stored in the pipeline config; also controls whether the "safety checker disabled" warning is logged.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The `f` prefix was missing here, so `{self.__class__}` was emitted literally instead of the class.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. clean_caption (bool, defaults to `False`): If `True`, the function will preprocess and clean the provided caption before encoding. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if 
self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def check_inputs(
    self,
    prompt,
    image,
    mask_image,
    batch_size,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """
    Validate the arguments handed to the pipeline call and raise early on inconsistent input.

    Args:
        prompt (`str` or `List[str]`, *optional*): Text prompt; mutually exclusive with `prompt_embeds`.
        image: A `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or a list of those. Its batch size must equal
            `batch_size`.
        mask_image: Same accepted types as `image`. Its batch size must be `1` or equal to `batch_size`.
        batch_size (`int`): Prompt batch size that the image inputs are compared against.
        callback_steps (`int`): Must be a positive integer.
        negative_prompt (`str` or `List[str]`, *optional*): Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds (`torch.Tensor`, *optional*): Pre-computed prompt embeddings.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed negative embeddings; must match the shape of `prompt_embeds` when both are passed.

    Raises:
        ValueError: On any violated constraint above, including an empty `image` / `mask_image` list.
    """

    def _validated_batch_size(value, arg_name):
        # Check the element type of `value` and return how many images it holds.
        is_list = isinstance(value, list)
        if is_list and len(value) == 0:
            # Previously surfaced as a bare IndexError from `value[0]`.
            raise ValueError(f"`{arg_name}` must not be an empty list.")

        sample = value[0] if is_list else value
        if not (
            isinstance(sample, torch.Tensor)
            or isinstance(sample, PIL.Image.Image)
            or isinstance(sample, np.ndarray)
        ):
            raise ValueError(
                f"`{arg_name}` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
                f" {type(sample)}"
            )

        if is_list:
            return len(value)
        if isinstance(value, torch.Tensor):
            return value.shape[0]
        if isinstance(value, PIL.Image.Image):
            return 1
        # Only np.ndarray is left after the type check above.
        return value.shape[0]

    # `None` is not an int, so it is rejected by the isinstance test as well.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # image: one image per prompt
    image_batch_size = _validated_batch_size(image, "image")
    if batch_size != image_batch_size:
        raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")

    # mask_image: either a single mask or one mask per prompt
    mask_batch_size = _validated_batch_size(mask_image, "mask_image")
    if mask_batch_size != 1 and batch_size != mask_batch_size:
        raise ValueError(
            f"mask_image batch size: {mask_batch_size} must be `1` or the same as prompt batch size {batch_size}"
        )
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
    """
    Normalize a raw caption before tokenization.

    The caption is URL-unquoted, lower-cased and then passed through a fixed sequence of substitutions that
    drop URLs, HTML markup, @-handles, CJK ranges, IP addresses, ids, file names and marketing boilerplate, and
    that unify dashes and quotes. Requires `BeautifulSoup` and `ftfy` to be importable.

    Args:
        caption: Any object; it is converted with `str()` first.

    Returns:
        `str`: The cleaned caption with surrounding whitespace stripped.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # Replace the angle-bracketed person placeholder tag with the bare word. The pattern is written with
    # \x escapes (0x3c / 0x3e are the angle brackets) so HTML-aware tooling cannot strip it; an empty pattern
    # here would insert "person" at every position of the caption.
    caption = re.sub(r"\x3cperson\x3e", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @-handles
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # HTML entity for the double quote (ampersand + "quot", optional semicolon); \x26 is the ampersand
    caption = re.sub(r"\x26quot;?", "", caption)

    # HTML entity prefix for the ampersand itself (ampersand + "amp")
    caption = re.sub(r"\x26amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # \n
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # repeated quotes / dots and stray punctuation
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE(review): a bare `caption.strip()` whose result was discarded used to sit here. It is dropped as
    # dead code rather than "fixed", so the output for every caption stays exactly what it was; confirm
    # before making the strip effective, since that would change which captions match the anchors below.

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
    """
    Select the tail of the scheduler's timesteps that corresponds to `strength`.

    Args:
        num_inference_steps (`int`): Number of steps the scheduler was configured with.
        strength (`float`): Fraction of `num_inference_steps` to actually run.

    Returns:
        `tuple`: `(timesteps, steps_to_run)` where `timesteps` is `self.scheduler.timesteps` sliced from the
        computed start index and `steps_to_run` is the number of inference steps left after skipping.
    """
    scheduler = self.scheduler

    # steps kept, capped at the requested total
    kept_steps = min(int(num_inference_steps * strength), num_inference_steps)
    skipped_steps = max(num_inference_steps - kept_steps, 0)

    # each inference step spans `order` entries of the timestep schedule
    start_index = skipped_steps * scheduler.order
    remaining_timesteps = scheduler.timesteps[start_index:]

    # not every scheduler exposes `set_begin_index`
    if hasattr(scheduler, "set_begin_index"):
        scheduler.set_begin_index(start_index)

    return remaining_timesteps, num_inference_steps - skipped_steps
len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) image = image.repeat_interleave(num_images_per_prompt, dim=0) noised_image = self.scheduler.add_noise(image, noise, timestep) image = (1 - mask_image) * image + mask_image * noised_image return image @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, image: Union[ PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] ] = None, mask_image: Union[ PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] ] = None, strength: float = 1.0, num_inference_steps: int = 50, timesteps: List[int] = None, guidance_scale: float = 7.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, clean_caption: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.Tensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[
        PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
    ] = None,
    mask_image: Union[
        PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
    ] = None,
    strength: float = 1.0,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    clean_caption: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        image (`torch.Tensor` or `PIL.Image.Image`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process.
        mask_image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
            repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
            to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
            instead of 3, so the expected shape would be `(B, H, W, 1)`.
        strength (`float`, *optional*, defaults to 1.0):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
            timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 7.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.deepfloyd_if.pipeline_output.IFPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
            be installed. If the dependencies are not installed, the embeddings will be created from the raw
            prompt.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

    Examples:

    Returns:
        [`~pipelines.deepfloyd_if.pipeline_output.IFPipelineOutput`] or `tuple`:
        [`~pipelines.deepfloyd_if.pipeline_output.IFPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
        element is a list of `bool`s denoting whether the corresponding generated image likely represents
        "not-safe-for-work" (nsfw) or watermarked content, according to the `safety_checker`.

    Raises:
        ValueError: if neither `prompt` nor `prompt_embeds` is given, if `prompt` has an unsupported type, or if
            `check_inputs` rejects the arguments.
    """
    # 1. Check inputs. Raise error if not correct
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    elif prompt_embeds is not None:
        batch_size = prompt_embeds.shape[0]
    elif prompt is None:
        # Fail with a clear message here: the batch size cannot be derived, and
        # dereferencing `prompt_embeds` would raise an opaque AttributeError.
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    else:
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    self.check_inputs(
        prompt,
        image,
        mask_image,
        batch_size,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
    )

    if do_classifier_free_guidance:
        # Unconditional embeddings first, text embeddings second; the
        # `chunk(2)` in the denoising loop relies on this order.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    dtype = prompt_embeds.dtype

    # 4. Prepare timesteps
    if timesteps is not None:
        self.scheduler.set_timesteps(timesteps=timesteps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)

    # 5. Prepare intermediate images
    image = self.preprocess_image(image)
    image = image.to(device=device, dtype=dtype)

    mask_image = self.preprocess_mask_image(mask_image)
    mask_image = mask_image.to(device=device, dtype=dtype)

    # A single mask is shared by the whole effective batch; a batch of masks
    # is only expanded per generated image.
    if mask_image.shape[0] == 1:
        mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
    else:
        mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)

    noise_timestep = timesteps[0:1]
    noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)

    intermediate_images = self.prepare_intermediate_images(
        image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # HACK: see comment in `enable_model_cpu_offload`
    if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
        self.text_encoder_offload_hook.offload()

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            model_input = (
                torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
            )
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # The channel axis is split into two chunks of `model_input.shape[1]`
                # channels; the second chunk is treated as the predicted variance.
                noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
                noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            prev_intermediate_images = intermediate_images

            intermediate_images = self.scheduler.step(
                noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
            )[0]

            # Only masked (repainted) pixels take the scheduler's update.
            intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, intermediate_images)

    image = intermediate_images

    if output_type == "pil":
        # 8. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 9. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 10. Convert to PIL
        image = self.numpy_to_pil(image)

        # 11. Apply watermark
        if self.watermarker is not None:
            self.watermarker.apply_watermark(image, self.unet.config.sample_size)
    elif output_type == "pt":
        nsfw_detected = None
        watermark_detected = None

        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
            self.unet_offload_hook.offload()
    else:
        # 8. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 9. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, nsfw_detected, watermark_detected)

    return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    """Resize a PIL image so its shorter side equals `img_size`.

    The longer side is scaled by the aspect ratio and rounded to a multiple
    of 8. Resampling uses the "bicubic" entry of `PIL_INTERPOLATION`.
    """
    src_width, src_height = images.size
    aspect = src_width / src_height

    if aspect >= 1:
        # Landscape or square: height is pinned, width follows the ratio.
        target = (int(round(img_size / 8 * aspect) * 8), img_size)
    else:
        # Portrait: width is pinned, height follows the ratio.
        target = (img_size, int(round(img_size / 8 / aspect) * 8))

    return images.resize(target, resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    image_noising_scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and sanity-check their combination.

    Logs a warning when the safety checker is disabled although
    `requires_safety_checker` is True, and when `unet.config.in_channels`
    is not 6.

    Raises:
        ValueError: if a `safety_checker` is given without a
            `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The f-prefix is required: without it `{self.__class__}` is printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    if unet.config.in_channels != 6:
        # Same here: the placeholders are only interpolated in an f-string.
        logger.warning(
            f"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        image_noising_scheduler=image_noising_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) if unet.config.in_channels != 6: logger.warning( "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." ) self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, image_noising_scheduler=image_noising_scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, watermarker=watermarker, ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing def _text_preprocessing(self, text, clean_caption=False): if clean_caption and not is_bs4_available(): logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) logger.warning("Setting `clean_caption` to False...") clean_caption = False if clean_caption and not is_ftfy_available(): logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) logger.warning("Setting `clean_caption` to False...") clean_caption = False if not isinstance(text, (tuple, list)): text = [text] def process(text: str): if clean_caption: text = self._clean_caption(text) text = self._clean_caption(text) else: text = text.lower().strip() return text return [process(t) for t in text] # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption def _clean_caption(self, caption): caption = str(caption) caption = ul.unquote_plus(caption) caption = 
def _clean_caption(self, caption):
    """Scrub a raw caption of web noise before it is tokenised.

    The steps, in order: URL-unquote and lower-case; drop URLs, HTML markup,
    @-mentions and several CJK ranges; unify dashes and quotes; remove HTML
    entity remnants, IP addresses, ids, file names and the characters in
    `self.bad_punct_regex`; run `ftfy.fix_text` and `html.unescape`; remove
    alphanumeric codes and shop/marketing phrases; finally collapse
    whitespace and trim stray leading/trailing punctuation.

    Requires `BeautifulSoup` and `ftfy` to be importable.

    Returns:
        The cleaned caption as a stripped string.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # The pattern must be the literal "<person>" tag. An empty pattern matches at
    # every position and would insert "person" between all characters.
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @mentions
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalise quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # leftover "&quot;" entities (the trailing ";" is optional)
    caption = re.sub(r"&quot;?", "", caption)
    # leftover "&amp" entities; a bare "&" pattern would strip every ampersand
    caption = re.sub(r"&amp", "", caption)

    # ip adresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # \n
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    #
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    # Only captions with more than three separators are split into words.
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    # Unescape twice to undo doubly-escaped entities.
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE(review): the original had a bare `caption.strip()` here whose result
    # was discarded. It is omitted rather than assigned so the cleaned text stays
    # exactly what it was before; confirm before changing this.

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
@torch.no_grad()
# NOTE: intentionally diverges from IFPipeline.encode_prompt (string `negative_prompt` is broadcast, see below).
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    do_classifier_free_guidance: bool = True,
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    clean_caption: bool = False,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            whether to use classifier free guidance or not
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            number of images that should be generated per prompt
        device: (`torch.device`, *optional*):
            torch device to place the resulting embeddings on
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when `do_classifier_free_guidance` is `True`. A single
            string is applied to every item of the batch.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        clean_caption (bool, defaults to `False`):
            If `True`, the function will preprocess and clean the provided caption before encoding.

    Returns:
        `(prompt_embeds, negative_prompt_embeds)`, each repeated `num_images_per_prompt` times along the batch
        axis. `negative_prompt_embeds` is `None` when `do_classifier_free_guidance` is `False`.

    Raises:
        TypeError: if `prompt` and `negative_prompt` are both given but have different types.
        ValueError: if a list `negative_prompt` does not match the batch size.
    """
    if prompt is not None and negative_prompt is not None:
        if type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )

    if device is None:
        device = self._execution_device

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
    max_length = 77

    if prompt_embeds is None:
        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenise again without truncation, only to report what was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because the text encoder can only handle"
                f" sequences up to {max_length} tokens: {removed_text}"
            )

        attention_mask = text_inputs.attention_mask.to(device)

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

    # Prefer the text encoder's dtype; fall back to the unet when the text
    # encoder is not loaded.
    if self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    elif self.unet is not None:
        dtype = self.unet.dtype
    else:
        dtype = None

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif isinstance(negative_prompt, str):
            # Broadcast the single string over the batch. With one entry only, the
            # `.view(batch_size * num_images_per_prompt, ...)` below would fail or
            # silently mis-shape the tensor whenever batch_size > 1.
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
        # Pad to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        attention_mask = uncond_input.attention_mask.to(device)

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
    else:
        negative_prompt_embeds = None

    return prompt_embeds, negative_prompt_embeds
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker over `image`.

    Returns:
        `(image, nsfw_detected, watermark_detected)`. Without a safety checker
        the image is returned unchanged and both flags are `None`.
    """
    if self.safety_checker is None:
        return image, None, None

    # The checker takes CLIP features computed from the PIL version of the images.
    clip_features = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
    image, nsfw_detected, watermark_detected = self.safety_checker(
        images=image,
        clip_input=clip_features.pixel_values.to(dtype=dtype),
    )
    return image, nsfw_detected, watermark_detected


def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Not all schedulers share one signature, so `eta` and `generator` are only
    forwarded when the scheduler's `step` declares a parameter of that name.
    eta (η) corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502)
    and should be between [0, 1].
    """
    step_params = set(inspect.signature(self.scheduler.step).parameters.keys())

    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
def check_inputs(
    self,
    prompt,
    image,
    original_image,
    mask_image,
    batch_size,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the arguments of the pipeline call.

    Checks, in order: `callback_steps` is a positive int; exactly one of
    `prompt` / `prompt_embeds` is given and `prompt` is a `str` or `list`;
    `negative_prompt` and `negative_prompt_embeds` are not both given; the two
    embedding tensors share a shape; `image`, `original_image` and
    `mask_image` have a supported type and a batch size matching
    `batch_size` (the mask may also have batch size 1).

    Raises:
        ValueError: on the first violated condition.
    """

    def _typed_batch_size(arg_name, value):
        # Type-check `value` (or its first element if it is a list) and return its batch size.
        sample = value[0] if isinstance(value, list) else value
        if (
            not isinstance(sample, torch.Tensor)
            and not isinstance(sample, PIL.Image.Image)
            and not isinstance(sample, np.ndarray)
        ):
            raise ValueError(
                f"`{arg_name}` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
                f" {type(sample)}"
            )

        if isinstance(value, list):
            return len(value)
        if isinstance(value, torch.Tensor):
            return value.shape[0]
        if isinstance(value, PIL.Image.Image):
            return 1
        if isinstance(value, np.ndarray):
            return value.shape[0]
        assert False

    # callback_steps
    if not (isinstance(callback_steps, int) and callback_steps > 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # prompt / prompt_embeds
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # negative prompt / embeddings
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # image
    image_batch_size = _typed_batch_size("image", image)
    if batch_size != image_batch_size:
        raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")

    # original_image
    image_batch_size = _typed_batch_size("original_image", original_image)
    if batch_size != image_batch_size:
        raise ValueError(
            f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}"
        )

    # mask_image (a single mask may be shared by the whole batch)
    image_batch_size = _typed_batch_size("mask_image", mask_image)
    if image_batch_size != 1 and batch_size != image_batch_size:
        raise ValueError(
            f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}"
        )
def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor:
    """Convert the original image(s) into one batched tensor.

    Accepts a single item or a list of PIL images, numpy arrays or torch
    tensors. PIL images are converted to RGB, resized with `resize` to
    `self.unet.config.sample_size` and scaled as `x / 127.5 - 1`. Numpy input
    is moved from channels-last to channels-first. Tensors are only
    concatenated (4-D items) or stacked (otherwise).
    """
    items = image if isinstance(image, list) else [image]
    head = items[0]

    def _channels_first(batch):
        # Move a channels-last numpy batch to a channels-first tensor; a 3-D batch gets a trailing channel axis first.
        if batch.ndim == 3:
            batch = batch[..., None]
        return torch.from_numpy(batch.transpose(0, 3, 1, 2))

    if isinstance(head, PIL.Image.Image):
        scaled = [
            np.array(resize(pil_img.convert("RGB"), self.unet.config.sample_size)).astype(np.float32) / 127.5 - 1
            for pil_img in items
        ]
        return _channels_first(np.stack(scaled, axis=0))

    if isinstance(head, np.ndarray):
        if head.ndim == 4:
            merged = np.concatenate(items, axis=0)
        else:
            merged = np.stack(items, axis=0)
        return _channels_first(merged)

    if isinstance(head, torch.Tensor):
        if head.ndim == 4:
            return torch.cat(items, axis=0)
        return torch.stack(items, axis=0)

    # Unsupported element types are returned unchanged (as the list).
    return items
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.preprocess_mask_image
def preprocess_mask_image(self, mask_image) -> torch.Tensor:
    """Convert the mask(s) into a binarized tensor.

    A non-list input is wrapped in a list first. The type of the first element selects the conversion:

    - `torch.Tensor`: 4-D elements are concatenated on dim 0, anything else is stacked on a new dim 0. A 2-D
      result gets batch and channel axes prepended; a 3-D result gets a batch axis when its first dimension
      is 1 and a channel axis (dim 1) otherwise.
    - `PIL.Image.Image`: each mask is converted to mode "L", passed through `resize` with
      `self.unet.config.sample_size`, given two leading axes, concatenated and divided by 255.
    - `np.ndarray`: each mask is given two leading axes and concatenated.

    In all three cases values below 0.5 are set to 0 and the remaining values >= 0.5 are set to 1. Elements of
    any other type are returned unchanged (as a list).
    """
    masks = mask_image if isinstance(mask_image, list) else [mask_image]
    head = masks[0]

    def binarize_(values):
        # In-place thresholding; works for both torch tensors and numpy arrays.
        values[values < 0.5] = 0
        values[values >= 0.5] = 1
        return values

    if isinstance(head, torch.Tensor):
        merge = torch.cat if head.ndim == 4 else torch.stack
        merged = merge(masks, dim=0)

        if merged.ndim == 2:
            # Single mask without batch/channel axes: prepend both.
            merged = merged.unsqueeze(0).unsqueeze(0)
        elif merged.ndim == 3:
            if merged.shape[0] == 1:
                # Leading axis of size 1 is treated as an existing batch of one.
                merged = merged.unsqueeze(0)
            else:
                # Leading axis is the batch axis; insert the channel axis after it.
                merged = merged.unsqueeze(1)

        return binarize_(merged)

    if isinstance(head, PIL.Image.Image):
        planes = []
        for pil_mask in masks:
            resized = resize(pil_mask.convert("L"), self.unet.config.sample_size)
            planes.append(np.array(resized)[None, None, :])

        merged = np.concatenate(planes, axis=0).astype(np.float32) / 255.0
        return torch.from_numpy(binarize_(merged))

    if isinstance(head, np.ndarray):
        merged = np.concatenate([plane[None, None, :] for plane in masks], axis=0)
        return torch.from_numpy(binarize_(merged))

    return masks
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
    original_image: Union[
        PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
    ] = None,
    mask_image: Union[
        PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
    ] = None,
    strength: float = 0.8,
    prompt: Union[str, List[str]] = None,
    num_inference_steps: int = 100,
    timesteps: List[int] = None,
    guidance_scale: float = 4.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    noise_level: int = 0,
    clean_caption: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        image (`torch.Tensor`, `np.ndarray` or `PIL.Image.Image`):
            Image (or image batch) used as conditioning. It is preprocessed, resized bilinearly to the height
            and width of the preprocessed `original_image`, noised with `self.image_noising_scheduler` at
            `noise_level`, and concatenated to the intermediate images along the channel axis at every step.
        original_image (`torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or a list of these):
            The original image that `image` was varied from. After preprocessing it is noised at the first
            remaining timestep inside the masked region and used as the starting point of the denoising loop.
        mask_image (`torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or a list of these):
            `Image`, or tensor representing an image batch, to mask the image. White pixels in the mask will
            be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be
            converted to a single channel (luminance) before use. Tensors are normalized by
            `preprocess_mask_image`, which adds the missing batch/channel axes and binarizes at 0.5. A mask
            batch of size 1 is repeated for every generated image.
        strength (`float`, *optional*, defaults to 0.8):
            Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. The
            larger the `strength`, the more noise is added and the more denoising steps are run. When
            `strength` is 1, added noise will be maximum and the denoising process will run for the full
            number of iterations specified in `num_inference_steps`.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        num_inference_steps (`int`, *optional*, defaults to 100):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference. Overridden by `len(timesteps)` when `timesteps` is passed.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced
            `num_inference_steps` timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of
            [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
            to the text `prompt`, usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
            `guidance_scale` is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only
            forwarded to the scheduler if its `step` method accepts it (see `prepare_extra_step_kwargs`).
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"pil"` returns PIL images, `"pt"` returns the raw
            tensor without post-processing or safety checking, any other value returns a NumPy array.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain
            tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will
            be called with the following arguments: `callback(step: int, timestep: int, latents:
            torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback
            will be called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined
            under `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        noise_level (`int`, *optional*, defaults to 0):
            The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and
            `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created
            from the raw prompt.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple` of
        `(images, nsfw_detected, watermark_detected)`. The two detection entries are whatever
        `self.run_safety_checker` returns, and are both `None` when `output_type == "pt"`.
    """
    # 1. Check inputs. Raise error if not correct
    # NOTE(review): if both `prompt` and `prompt_embeds` are None this fails on `.shape` before
    # `check_inputs` can report the problem.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    self.check_inputs(
        prompt,
        image,
        original_image,
        mask_image,
        batch_size,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    device = self._execution_device

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
    )

    # Unconditional embeddings go first so that `chunk(2)` in the loop yields (uncond, text).
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    dtype = prompt_embeds.dtype

    # 4. Prepare timesteps
    if timesteps is not None:
        self.scheduler.set_timesteps(timesteps=timesteps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    # Drop the leading part of the schedule according to `strength`.
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)

    # 5. Prepare original image
    original_image = self.preprocess_original_image(original_image)
    original_image = original_image.to(device=device, dtype=dtype)

    # 6. Prepare mask image
    mask_image = self.preprocess_mask_image(mask_image)
    mask_image = mask_image.to(device=device, dtype=dtype)

    # A single mask is shared by every generated image; a batch of masks is repeated per prompt.
    if mask_image.shape[0] == 1:
        mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
    else:
        mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)

    # 7. Prepare intermediate images (original image noised at the first remaining timestep)
    noise_timestep = timesteps[0:1]
    noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)

    intermediate_images = self.prepare_intermediate_images(
        original_image,
        noise_timestep,
        batch_size,
        num_images_per_prompt,
        dtype,
        device,
        mask_image,
        generator,
    )

    # 8. Prepare upscaled image and noise level
    _, _, height, width = original_image.shape

    image = self.preprocess_image(image, num_images_per_prompt, device)

    upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)

    # From here on `noise_level` is a per-sample tensor, also passed to the UNet as `class_labels`.
    noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
    noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
    upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)

    # Match the doubled batch used for classifier free guidance.
    if do_classifier_free_guidance:
        noise_level = torch.cat([noise_level] * 2)

    # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # HACK: see comment in `enable_model_cpu_offload`
    if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
        self.text_encoder_offload_hook.offload()

    # 10. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # The UNet input is the current sample concatenated with the noised upscaled image (channel axis).
            model_input = torch.cat([intermediate_images, upscaled], dim=1)
            model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                class_labels=noise_level,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # Each half is split along channels; the second part of the text half is kept as
                # `predicted_variance` and re-attached after guidance.
                noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # Schedulers without a learned variance only receive the first channel group.
            if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
                noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            prev_intermediate_images = intermediate_images

            intermediate_images = self.scheduler.step(
                noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
            )[0]

            # Only the masked region takes the scheduler update; elsewhere the previous sample is kept.
            intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, intermediate_images)

    image = intermediate_images

    if output_type == "pil":
        # 11. Post-processing: map from [-1, 1] to [0, 1] and move to channels-last NumPy
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 12. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 13. Convert to PIL
        image = self.numpy_to_pil(image)

        # 14. Apply watermark
        if self.watermarker is not None:
            self.watermarker.apply_watermark(image, self.unet.config.sample_size)
    elif output_type == "pt":
        # Raw tensor output: no post-processing, safety check or watermark.
        nsfw_detected = None
        watermark_detected = None
    else:
        # 11. Post-processing: map from [-1, 1] to [0, 1] and move to channels-last NumPy
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 12. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, nsfw_detected, watermark_detected)

    return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images >>> # save intermediate image >>> pil_image = pt_to_pil(image) >>> pil_image[0].save("./if_stage_I.png") >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 ... ) >>> super_res_1_pipe.enable_model_cpu_offload() >>> image = super_res_1_pipe( ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds ... ).images >>> image[0].save("./if_stage_II.png") ``` """ class IFSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): tokenizer: T5Tokenizer text_encoder: T5EncoderModel unet: UNet2DConditionModel scheduler: DDPMScheduler image_noising_scheduler: DDPMScheduler feature_extractor: Optional[CLIPImageProcessor] safety_checker: Optional[IFSafetyChecker] watermarker: Optional[IFWatermarker] bad_punct_regex = re.compile( r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}" ) # noqa _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] model_cpu_offload_seq = "text_encoder->unet" _exclude_from_cpu_offload = ["watermarker"] def __init__( self, tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, unet: UNet2DConditionModel, scheduler: DDPMScheduler, image_noising_scheduler: DDPMScheduler, safety_checker: Optional[IFSafetyChecker], feature_extractor: Optional[CLIPImageProcessor], watermarker: Optional[IFWatermarker], requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the IF license and do not expose unfiltered" " results in services or applications open to the public. 
Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) if unet.config.in_channels != 6: logger.warning( "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
    """Normalize a raw caption before it is tokenized.

    The caption is converted to `str`, URL-unquoted and lower-cased, then passed through a fixed sequence of
    substitutions that remove URLs, HTML markup, @-handles, CJK characters, IP addresses, ids, file names,
    marketing boilerplate and stray punctuation, and that normalize dashes, quotes and whitespace. The order
    of the substitutions matters: later patterns rely on the normalization done by earlier ones.

    Requires `bs4.BeautifulSoup` and `ftfy` to be importable (callers check this in `_text_preprocessing`).

    Args:
        caption: Any object; it is converted with `str()`.

    Returns:
        `str`: the cleaned caption, stripped of leading and trailing whitespace.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # Replace the literal "<person>" placeholder. An empty pattern here would insert "person" at every
    # position in the caption.
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize all quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # leftover HTML entity: &quot (with or without the trailing semicolon)
    caption = re.sub(r"&quot;?", "", caption)
    # leftover HTML entity: &amp (only the entity, not every ampersand)
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # literal "\n" sequences
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # repeated quotes / dots / punctuation
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    # Only captions with more than three separators are treated as slug-like and split into words.
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    # Unescape twice to also resolve double-escaped entities.
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    # dimensions such as "1920x1080" (latin x, cyrillic х or ×)
    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # `str.strip` returns a new string; the result must be assigned, otherwise the `^...$` anchored
    # patterns below do not see the caption edges when whitespace is left over.
    caption = caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt device: (`torch.device`, *optional*): torch device to place the resulting embeddings on negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. clean_caption (bool, defaults to `False`): If `True`, the function will preprocess and clean the provided caption before encoding. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif 
isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def check_inputs(
    self,
    prompt,
    image,
    batch_size,
    noise_level,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """
    Validate the arguments of a pipeline call before any model work is done.

    Args:
        prompt: A `str`, a `list`, or `None` when `prompt_embeds` is given instead.
        image: A `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or a non-empty list whose first element is
            one of those types.
        batch_size: The prompt batch size the image batch must agree with.
        noise_level: Must satisfy `0 <= noise_level < self.image_noising_scheduler.config.num_train_timesteps`.
        callback_steps: Must be a positive `int`.
        negative_prompt: Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: Mutually exclusive with `prompt`.
        negative_prompt_embeds: Must have the same shape as `prompt_embeds` when both are passed.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    # `isinstance(None, int)` is False, so this single test also rejects `callback_steps=None`.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    num_train_timesteps = self.image_noising_scheduler.config.num_train_timesteps
    if noise_level < 0 or noise_level >= num_train_timesteps:
        # The message names the attribute that is actually consulted (`image_noising_scheduler`).
        raise ValueError(
            f"`noise_level`: {noise_level} must be a valid timestep in `self.image_noising_scheduler`,"
            f" [0, {num_train_timesteps})"
        )

    if isinstance(image, list):
        # Fail with a clear message instead of an IndexError from `image[0]`.
        if not image:
            raise ValueError("`image` must not be an empty list.")
        check_image_type = image[0]
    else:
        check_image_type = image

    # Only the first element of a list is inspected, as before.
    if (
        not isinstance(check_image_type, torch.Tensor)
        and not isinstance(check_image_type, np.ndarray)
        and not isinstance(check_image_type, PIL.Image.Image)
    ):
        raise ValueError(
            "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
            f" {type(check_image_type)}"
        )

    if isinstance(image, list):
        image_batch_size = len(image)
    elif isinstance(image, (torch.Tensor, np.ndarray)):
        image_batch_size = image.shape[0]
    else:
        # The type check above leaves a single `PIL.Image.Image` as the only remaining possibility.
        image_batch_size = 1

    if batch_size != image_batch_size:
        raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
# NOTE(review): the empty `Examples:` section in the docstring below looks like the placeholder that
# `replace_example_docstring` fills in with EXAMPLE_DOC_STRING - keep that line when editing the docstring.
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: int = None,
    width: int = None,
    image: Union[PIL.Image.Image, np.ndarray, torch.Tensor] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    guidance_scale: float = 4.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    noise_level: int = 250,
    clean_caption: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size`):
            The width in pixels of the generated image.
        image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`):
            The image to be upscaled.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*, defaults to None):
            Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
            timesteps are used. Must be in descending order. When given, `num_inference_steps` is replaced by the
            number of timesteps the scheduler ends up with.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"pil"` returns PIL images (safety-checked and, when a
            watermarker is set, watermarked); `"pt"` returns the final tensor as is, without post-processing or
            safety checking; any other value returns a safety-checked `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        noise_level (`int`, *optional*, defaults to 250):
            The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
            be installed. If the dependencies are not installed, the embeddings will be created from the raw
            prompt.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the generated images, the second element is the `nsfw_detected`
        result and the third element is the `watermark_detected` result of the `safety_checker` (both are `None`
        when `output_type` is `"pt"` or no safety checker is set).
    """
    # 1. Check inputs. Raise error if not correct

    # The batch size comes from the prompt when one is given, otherwise from the leading dimension of
    # `prompt_embeds`.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    self.check_inputs(
        prompt,
        image,
        batch_size,
        noise_level,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters

    # A falsy `height`/`width` (None or 0) falls back to the UNet's configured sample size.
    height = height or self.unet.config.sample_size
    width = width or self.unet.config.sample_size

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
    )

    # With guidance the unconditional embeddings go first, matching the `chunk(2)` order in the loop below.
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare timesteps
    if timesteps is not None:
        self.scheduler.set_timesteps(timesteps=timesteps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    # Not every scheduler exposes `set_begin_index`, hence the guard.
    if hasattr(self.scheduler, "set_begin_index"):
        self.scheduler.set_begin_index(0)

    # 5. Prepare intermediate images
    # The UNet input is the channel-wise concatenation of the noisy sample and the upscaled image (see the
    # denoising loop), so the noisy sample is created with half of the UNet's input channels.
    num_channels = self.unet.config.in_channels // 2
    intermediate_images = self.prepare_intermediate_images(
        batch_size * num_images_per_prompt,
        num_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Prepare upscaled image and noise level
    image = self.preprocess_image(image, num_images_per_prompt, device)
    upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)

    # One noise-level entry per image; it is used both as the timestep for noising the upscaled image and as
    # the `class_labels` passed to the UNet.
    noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
    noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
    upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)

    # The model batch is doubled under guidance, so the labels are doubled as well.
    if do_classifier_free_guidance:
        noise_level = torch.cat([noise_level] * 2)

    # HACK: see comment in `enable_model_cpu_offload`
    if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
        self.text_encoder_offload_hook.offload()

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            model_input = torch.cat([intermediate_images, upscaled], dim=1)

            # Duplicate the batch so a single UNet call yields both the unconditional and the text prediction.
            model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                class_labels=noise_level,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # Each half is split along channels; the second part of the text branch is kept as the
                # predicted variance, the second part of the unconditional branch is discarded.
                noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # Schedulers that do not use a learned variance only receive the noise channels.
            if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
                noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            intermediate_images = self.scheduler.step(
                noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
            )[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, intermediate_images)

    image = intermediate_images

    if output_type == "pil":
        # 9. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 10. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 11. Convert to PIL
        image = self.numpy_to_pil(image)

        # 12. Apply watermark
        if self.watermarker is not None:
            self.watermarker.apply_watermark(image, self.unet.config.sample_size)
    elif output_type == "pt":
        # The raw tensor is returned untouched: no rescaling, no safety check, no watermark.
        nsfw_detected = None
        watermark_detected = None

        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
            self.unet_offload_hook.offload()
    else:
        # Any other output type yields a safety-checked numpy array (no PIL conversion, no watermark).
        # 9. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 10. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, nsfw_detected, watermark_detected)

    return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
class IFSafetyChecker(PreTrainedModel):
    """
    Safety checker built from a CLIP vision model with projection and two single-output linear heads.

    `p_head` scores the projected image embedding for NSFW content and `w_head` scores it for watermarks.
    Images whose score exceeds the corresponding threshold are replaced by all-zero arrays.
    """

    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        vision_config = config.vision_config
        self.vision_model = CLIPVisionModelWithProjection(vision_config)
        self.p_head = nn.Linear(vision_config.projection_dim, 1)
        self.w_head = nn.Linear(vision_config.projection_dim, 1)

    @torch.no_grad()
    def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5):
        """
        Flag NSFW and watermarked images and blank out every flagged image.

        Args:
            clip_input: Input passed to the CLIP vision model.
            images: Indexable collection of images; flagged entries are overwritten in place with
                `np.zeros` arrays of the same shape.
            p_threshold: Score above which an image counts as NSFW.
            w_threshold: Score above which an image counts as watermarked.

        Returns:
            A tuple `(images, nsfw_detected, watermark_detected)` where the last two are lists of `bool`,
            one entry per image.
        """
        embeddings = self.vision_model(clip_input)[0]

        # NSFW head: one score per image, thresholded into a plain list of booleans.
        nsfw_detected = (self.p_head(embeddings).flatten() > p_threshold).tolist()

        if any(nsfw_detected):
            logger.warning(
                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        for position, flagged in enumerate(nsfw_detected):
            if not flagged:
                continue
            images[position] = np.zeros(images[position].shape)

        # Watermark head: same procedure on the same embeddings.
        watermark_detected = (self.w_head(embeddings).flatten() > w_threshold).tolist()

        if any(watermark_detected):
            logger.warning(
                "Potential watermarked content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        for position, flagged in enumerate(watermark_detected):
            if not flagged:
                continue
            images[position] = np.zeros(images[position].shape)

        return images, nsfw_detected, watermark_detected
984, 981, 979, 977, 975, 972, 970, 968, 966, 964, 961, 959, 957, 956, 954, 951, 949, 946, 944, 941, 939, 936, 934, 931, 929, 926, 924, 921, 919, 916, 914, 913, 910, 907, 905, 902, 899, 896, 893, 891, 888, 885, 882, 879, 877, 874, 871, 870, 867, 864, 861, 858, 855, 852, 849, 846, 843, 840, 837, 834, 831, 828, 827, 824, 821, 817, 814, 811, 808, 804, 801, 798, 795, 791, 788, 785, 784, 780, 777, 774, 770, 766, 763, 760, 756, 752, 749, 746, 742, 741, 737, 733, 730, 726, 722, 718, 714, 710, 707, 703, 699, 698, 694, 690, 685, 681, 677, 673, 669, 664, 660, 656, 655, 650, 646, 641, 636, 632, 627, 622, 618, 613, 612, 607, 602, 596, 591, 586, 580, 575, 570, 569, 563, 557, 551, 545, 539, 533, 527, 526, 519, 512, 505, 498, 491, 484, 483, 474, 466, 457, 449, 440, 439, 428, 418, 407, 396, 395, 381, 366, 352, 351, 330, 308, 307, 286, 264, 263, 242, 220, 219, 176, 175, 132, 131, 88, 44, 0, ] super27_timesteps = [ 999, 991, 982, 974, 966, 958, 950, 941, 933, 925, 916, 908, 900, 899, 874, 850, 825, 800, 799, 700, 600, 500, 400, 300, 200, 100, 0, ] super40_timesteps = [ 999, 992, 985, 978, 971, 964, 957, 949, 942, 935, 928, 921, 914, 907, 900, 899, 879, 859, 840, 820, 800, 799, 766, 733, 700, 699, 650, 600, 599, 500, 499, 400, 399, 300, 299, 200, 199, 100, 99, 0, ] super100_timesteps = [ 999, 996, 992, 989, 985, 982, 979, 975, 972, 968, 965, 961, 958, 955, 951, 948, 944, 941, 938, 934, 931, 927, 924, 920, 917, 914, 910, 907, 903, 900, 899, 891, 884, 876, 869, 861, 853, 846, 838, 830, 823, 815, 808, 800, 799, 788, 777, 766, 755, 744, 733, 722, 711, 700, 699, 688, 677, 666, 655, 644, 633, 622, 611, 600, 599, 585, 571, 557, 542, 528, 514, 500, 499, 485, 471, 457, 442, 428, 414, 400, 399, 379, 359, 340, 320, 300, 299, 279, 259, 240, 220, 200, 199, 166, 133, 100, 99, 66, 33, 0, ] ================================================ FILE: src/diffusers/pipelines/deepfloyd_if/watermark.py ================================================ from typing import List import PIL.Image import torch from 
class IFWatermarker(ModelMixin, ConfigMixin):
    """
    Pastes a watermark image into the bottom-right area of generated PIL images.

    The watermark is stored as a `(62, 62, 4)` buffer named `watermark_image` (initialised to zeros here) and
    is converted to an RGBA PIL image the first time it is needed.
    """

    def __init__(self):
        super().__init__()

        self.register_buffer("watermark_image", torch.zeros((62, 62, 4)))
        # Lazily created PIL version of the buffer, see `apply_watermark`.
        self.watermark_image_as_pil = None

    def apply_watermark(self, images: List[PIL.Image.Image], sample_size=None):
        """
        Paste the watermark onto every image in `images` (in place) and return the same list.

        The watermark size and position are derived from the first image only, so all images are expected to
        share its dimensions. `sample_size` falls back to the height of the first image when falsy.
        """
        # Copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287

        first = images[0]
        h, w = first.height, first.width

        if not sample_size:
            sample_size = h

        coef = min(h / sample_size, w / sample_size)
        if coef < 1:
            img_h, img_w = int(h / coef), int(w / coef)
        else:
            img_h, img_w = h, w

        # Scale factor relative to a 1024x1024 reference area.
        reference_area = 1024**2
        image_area = img_w * img_h
        scale = (image_area / reference_area) ** 0.5

        wm_size = int(scale * 62)
        right = img_w - int(14 * scale)
        bottom = img_h - int(14 * scale)

        if self.watermark_image_as_pil is None:
            pixels = self.watermark_image.to(torch.uint8).cpu().numpy()
            self.watermark_image_as_pil = Image.fromarray(pixels, mode="RGBA")

        wm_img = self.watermark_image_as_pil.resize(
            (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None
        )

        paste_box = (right - wm_size, bottom - wm_size, right, bottom)
        for target in images:
            # The alpha channel (last band) of the watermark is used as the paste mask.
            target.paste(wm_img, box=paste_box, mask=wm_img.split()[-1])

        return images
================================================ FILE: src/diffusers/pipelines/deprecated/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_librosa_available, is_note_seq_available, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_pt_objects _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) else: _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"] _import_structure["pndm"] = ["PNDMPipeline"] _import_structure["repaint"] = ["RePaintPipeline"] _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"] _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"] try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["alt_diffusion"] = [ "AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline", "AltDiffusionPipelineOutput", ] _import_structure["versatile_diffusion"] = [ "VersatileDiffusionDualGuidedPipeline", "VersatileDiffusionImageVariationPipeline", "VersatileDiffusionPipeline", "VersatileDiffusionTextToImagePipeline", ] _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"] _import_structure["stable_diffusion_variants"] = [ "CycleDiffusionPipeline", "StableDiffusionInpaintPipelineLegacy", "StableDiffusionPix2PixZeroPipeline", "StableDiffusionParadigmsPipeline", "StableDiffusionModelEditingPipeline", ] try: if not (is_torch_available() and is_librosa_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: 
from ...utils import dummy_torch_and_librosa_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects)) else: _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"] try: if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) else: _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_pt_objects import * else: from .latent_diffusion_uncond import LDMPipeline from .pndm import PNDMPipeline from .repaint import RePaintPipeline from .score_sde_ve import ScoreSdeVePipeline from .stochastic_karras_ve import KarrasVePipeline try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, AltDiffusionPipelineOutput from .audio_diffusion import AudioDiffusionPipeline, Mel from .spectrogram_diffusion import SpectrogramDiffusionPipeline from .stable_diffusion_variants import ( CycleDiffusionPipeline, StableDiffusionInpaintPipelineLegacy, StableDiffusionModelEditingPipeline, StableDiffusionParadigmsPipeline, StableDiffusionPix2PixZeroPipeline, ) from .stochastic_karras_ve import KarrasVePipeline from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, ) from .vq_diffusion 
# Lazy-import entry point for the deprecated Alt Diffusion pipelines.
#
# `_import_structure` maps each submodule to the names it provides and is
# handed to `_LazyModule`; `_dummy_objects` collects the objects of the
# `dummy_torch_and_transformers_objects` module when torch or transformers is
# not available.
from typing import TYPE_CHECKING

from ....utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

# Everything in this package needs both transformers and torch.
try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ....utils import dummy_torch_and_transformers_objects

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modeling_roberta_series"] = ["RobertaSeriesModelWithTransformation"]
    _import_structure["pipeline_alt_diffusion"] = ["AltDiffusionPipeline"]
    _import_structure["pipeline_alt_diffusion_img2img"] = ["AltDiffusionImg2ImgPipeline"]

    _import_structure["pipeline_output"] = ["AltDiffusionPipelineOutput"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager imports, guarded by the same dependency check as the lazy table.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ....utils.dummy_torch_and_transformers_objects import *

    else:
        from .modeling_roberta_series import RobertaSeriesModelWithTransformation
        from .pipeline_alt_diffusion import AltDiffusionPipeline
        from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline
        from .pipeline_output import AltDiffusionPipelineOutput

else:
    import sys

    # Replace this module with a lazy proxy built from `_import_structure`.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    # Attach the placeholders collected when dependencies are missing.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
@dataclass
class TransformationModelOutput(ModelOutput):
    """
    Output of `RobertaSeriesModelWithTransformation`: the encoder outputs plus a linear projection of a hidden state.

    Args:
        projection_state (`torch.Tensor`, *optional*):
            Result of the model's linear projection layer. `RobertaSeriesModelWithTransformation.forward` applies
            the projection either to the last hidden state or, when the config sets `has_pre_transformation`, to
            the layer-normalized second-to-last hidden state. The last dimension is `config.project_dim`.
        last_hidden_state (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    # Declared first, so integer index 0 on a populated output yields the projection.
    projection_state: Optional[torch.Tensor] = None
    last_hidden_state: torch.Tensor = None
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    attentions: Optional[Tuple[torch.Tensor]] = None
class RobertaSeriesConfig(XLMRobertaConfig):
    """
    `XLMRobertaConfig` extended with the settings of the projection head used by Alt Diffusion.

    Args:
        pad_token_id, bos_token_id, eos_token_id (`int`):
            Special token ids, forwarded to `XLMRobertaConfig`.
        project_dim (`int`, defaults to 512):
            Output size of the linear projection applied on top of the encoder hidden states.
        pooler_fn (`str`, defaults to `"cls"`):
            Stored on the config; not read by the model code in this module.
        learn_encoder (`bool`, defaults to `False`):
            Stored on the config; not read by the model code in this module.
        use_attention_mask (`bool`, defaults to `True`):
            Stored on the config; the Alt Diffusion pipeline checks it to decide whether to pass the tokenizer's
            attention mask to the text encoder.
        **kwargs:
            Forwarded to `XLMRobertaConfig`.
    """

    def __init__(
        self,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        project_dim=512,
        pooler_fn="cls",
        learn_encoder=False,
        use_attention_mask=True,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder
        self.use_attention_mask = use_attention_mask


class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
    """
    XLM-RoBERTa encoder followed by a linear projection to `config.project_dim`.

    When the config defines a truthy `has_pre_transformation`, the projection is taken from the layer-normalized
    second-to-last hidden state (`pre_LN` + `transformation_pre`) instead of the last hidden state
    (`transformation`).
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    base_model_prefix = "roberta"
    config_class = RobertaSeriesConfig

    def __init__(self, config):
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        self.transformation = nn.Linear(config.hidden_size, config.project_dim)
        # Older configs may not define the flag at all, hence the getattr default.
        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
        if self.has_pre_transformation:
            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        r"""
        Run the encoder and project its hidden states.

        All tensor arguments and `output_attentions` are forwarded unchanged to the underlying encoder
        (`self.base_model`). `output_hidden_states` is forced to `True` when `has_pre_transformation` is set,
        because the projection then needs the second-to-last hidden state.

        `return_dict` is accepted for signature compatibility only: the result is always a
        [`TransformationModelOutput`], and the encoder is always called with `return_dict=True` because the code
        below reads its output by attribute/key. (Passing a value that resolved to `False` previously made the
        encoder return a tuple and this method fail.)

        Returns:
            [`TransformationModelOutput`] with `projection_state`, `last_hidden_state`, `hidden_states` and
            `attentions`.
        """
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
            return_dict=True,
        )

        if self.has_pre_transformation:
            # Project the layer-normalized penultimate hidden state.
            penultimate = self.pre_LN(outputs["hidden_states"][-2])
            projection_state = self.transformation_pre(penultimate)
        else:
            projection_state = self.transformation(outputs.last_hidden_state)

        return TransformationModelOutput(
            projection_state=projection_state,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch from packaging import version from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer from ....configuration_utils import FrozenDict from ....image_processor import PipelineImageInput, VaeImageProcessor from ....loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin, ) from ....models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ....models.lora import adjust_lora_scale_text_encoder from ....schedulers import KarrasDiffusionSchedulers from ....utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ....utils.torch_utils import randn_tensor from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker from .modeling_roberta_series import RobertaSeriesModelWithTransformation from .pipeline_output import AltDiffusionPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import AltDiffusionPipeline >>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") >>> # "dark elf princess, highly detailed, d & d, fantasy, highly detailed, digital painting, trending on artstation, concept art, sharp focus, illustration, art by artgerm and greg rutkowski 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend `noise_cfg` with a copy of itself rescaled to the per-sample standard deviation of `noise_pred_text`.

    Based on findings of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

    Args:
        noise_cfg: guided noise prediction.
        noise_pred_text: text-conditioned noise prediction that provides the target statistics.
        guidance_rescale: blend weight; `0.0` returns values equal to `noise_cfg`, `1.0` the fully rescaled
            prediction.

    Returns:
        The blended noise prediction, with the same shape as `noise_cfg`.
    """

    def _per_sample_std(tensor):
        # Reduce over every dimension except the leading (batch) one.
        non_batch_dims = list(range(1, tensor.ndim))
        return tensor.std(dim=non_batch_dims, keepdim=True)

    std_ratio = _per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg)
    # Rescaling the guided prediction counteracts overexposure ...
    rescaled = noise_cfg * std_ratio
    # ... and mixing it back with the original avoids "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one scheduling mode is used: custom `timesteps`, custom `sigmas`, or a plain step count. Extra keyword
    arguments are passed through to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure.
        num_inference_steps (`int`):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as the `device` argument.
        timesteps (`List[int]`, *optional*):
            Custom timesteps; requires a scheduler whose `set_timesteps` accepts `timesteps`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas; requires a scheduler whose `set_timesteps` accepts `sigmas`.

    Returns:
        `Tuple[torch.Tensor, int]`: the scheduler's `timesteps` after configuration and the number of inference
        steps (the schedule length in the custom modes, otherwise `num_inference_steps` as passed).

    Raises:
        ValueError: if both `timesteps` and `sigmas` are given, or the scheduler does not support the requested
            custom mode.
    """
    custom_timesteps = timesteps is not None
    custom_sigmas = sigmas is not None

    if custom_timesteps and custom_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler derive the schedule from the step count.
    if not custom_timesteps and not custom_sigmas:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom modes are only valid if `set_timesteps` exposes the matching keyword.
    supported = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if custom_timesteps:
        if "timesteps" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.RobertaSeriesModelWithTransformation`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.XLMRobertaTokenizer`]): A `XLMRobertaTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
""" model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, vae: AutoencoderKL, text_encoder: RobertaSeriesModelWithTransformation, tokenizer: XLMRobertaTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: RobertaSeriesModelWithTransformation,
    tokenizer: XLMRobertaTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    image_encoder: CLIPVisionModelWithProjection = None,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components after patching outdated scheduler/unet configs.

    Outdated settings (`steps_offset != 1`, `clip_sample=True`, a pre-0.9.0 unet with `sample_size < 64`) emit a
    deprecation notice and are overwritten on the component's config. A missing safety checker only logs a
    warning.

    Raises:
        ValueError: if a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        # The config is frozen, so swap in a patched copy.
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # f-string so that the pipeline class is actually interpolated into the message.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    # Spatial downscale between pixels and latents: one factor of 2 per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    **kwargs,
):
    """
    Deprecated wrapper around `encode_prompt()` that keeps the legacy return format.

    Emits a deprecation notice, delegates every argument to `encode_prompt()`, and returns a single tensor made of
    the negative embeddings followed by the positive embeddings.
    """
    deprecate(
        "_encode_prompt()",
        "1.0.0",
        "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.",
        standard_warn=False,
    )

    encoded = self.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
        **kwargs,
    )

    # Legacy layout: unconditional embeddings first, conditional second.
    negative_part = encoded[1]
    positive_part = encoded[0]
    return torch.cat([negative_part, positive_part])
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from the text encoder while computing the prompt embeddings. A value of
            1 means that the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        A `(prompt_embeds, negative_prompt_embeds)` tuple. `negative_prompt_embeds` is returned as passed in
        (possibly `None`) when `do_classifier_free_guidance` is false.

    Raises:
        TypeError: if `negative_prompt` and `prompt` are of different types.
        ValueError: if a list `negative_prompt` does not match the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # Batch size comes from the prompt when given, otherwise from the precomputed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation only to detect (and report) what was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            # NOTE(review): the `RobertaSeriesModelWithTransformation` class that ships with this pipeline does
            # not appear to define `text_model` -- confirm that the `clip_skip` path works with that encoder.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Cast the embeddings to the dtype of the text encoder, falling back to the unet, then to their own dtype.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode `image` with the image encoder and build the matching unconditional counterpart.

    Non-tensor inputs are first converted to pixel values with the feature extractor. The input is moved to
    `device` and to the dtype of the image encoder's parameters.

    Returns:
        A `(conditional, unconditional)` pair, each repeated `num_images_per_prompt` times along dim 0. With
        `output_hidden_states` truthy these are the encoder's second-to-last hidden states for the image and for
        an all-zero image; otherwise they are the `image_embeds` and a zero tensor of the same shape.
    """
    encoder_dtype = next(self.image_encoder.parameters()).dtype

    pixels = image
    if not isinstance(pixels, torch.Tensor):
        pixels = self.feature_extractor(pixels, return_tensors="pt").pixel_values
    pixels = pixels.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = self.image_encoder(pixels).image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    def _penultimate_states(batch):
        # Second-to-last hidden state, duplicated once per requested image.
        states = self.image_encoder(batch, output_hidden_states=True).hidden_states[-2]
        return states.repeat_interleave(num_images_per_prompt, dim=0)

    cond_states = _penultimate_states(pixels)
    uncond_states = _penultimate_states(torch.zeros_like(pixels))
    return cond_states, uncond_states
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker over `image`.

    Returns:
        `(image, has_nsfw_concept)`. Without a safety checker the image is returned untouched and the flag is
        `None`; otherwise both values come from the safety checker.
    """
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images, whatever form the input has.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    checker_inputs = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    clip_input = checker_inputs.pixel_values.to(dtype)
    checked_image, nsfw_flags = self.safety_checker(images=image, clip_input=clip_input)
    return checked_image, nsfw_flags


def decode_latents(self, latents):
    """
    Deprecated: decode `latents` with the VAE into a NumPy image batch.

    The latents are divided by the VAE scaling factor, decoded, mapped from `[-1, 1]` to `[0, 1]` with clamping,
    and returned channels-last as a float32 array on the CPU.
    """
    deprecate(
        "decode_latents",
        "1.0.0",
        "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead",
        standard_warn=False,
    )

    unscaled = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(unscaled, return_dict=False)[0]
    normalized = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return normalized.cpu().permute(0, 2, 3, 1).float().numpy()
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments accepted by `self.scheduler.step`.

    Not all schedulers share the same `step` signature: `eta` (η in the DDIM paper,
    https://arxiv.org/abs/2010.02502, expected in [0, 1]) and `generator` are only included when the scheduler's
    `step` declares a parameter of that name.
    """
    step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}


def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the user-facing call arguments.

    Raises:
        ValueError: on a height/width not divisible by 8, a non-positive or non-int `callback_steps`, unknown
            callback tensor names, a prompt and prompt embeddings given together or both missing, a prompt of the
            wrong type, a negative prompt and negative embeddings given together, or embeddings of mismatched
            shapes.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is not None:
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    if callback_on_step_end_tensor_inputs is not None:
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
            )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Produce the initial latents for the denoising loop.

    Fresh noise of shape `(batch_size, num_channels_latents, height // vae_scale_factor, width //
    vae_scale_factor)` is sampled when `latents` is `None`; otherwise the given latents are moved to `device`.
    In both cases the result is multiplied by the scheduler's `init_noise_sigma`.

    Raises:
        ValueError: if `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(width) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if isinstance(generator, list):
        if len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

    initial = (
        randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if latents is None
        else latents.to(device)
    )

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
    """
    Build sinusoidal embeddings of the guidance scale values in `w`.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scale values to embed; each value is multiplied by 1000 before embedding.
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, on the same device as `w`. The
        first half of each vector holds the sines, the second half the cosines; an odd `embedding_dim` is
        zero-padded by one column.
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # Build the frequency table on `w`'s device so the product below never mixes devices.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb


@property
def guidance_scale(self):
    """Value stored in `self._guidance_scale`."""
    return self._guidance_scale


@property
def guidance_rescale(self):
    """Value stored in `self._guidance_rescale`."""
    return self._guidance_rescale


@property
def clip_skip(self):
    """Value stored in `self._clip_skip`."""
    return self._clip_skip


# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas, forwarded together with `timesteps` to `retrieve_timesteps` when the scheduler
            timesteps are prepared.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. When set to
            `"latent"`, the VAE decode and the safety checker are skipped and the latents are post-processed
            directly.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its optional `"scale"` entry is also used as the LoRA scale when encoding the prompt.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
            using zero terminal SNR.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`. The returned dictionary may override `latents`, `prompt_embeds`
            and `negative_prompt_embeds` for the following steps.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        **kwargs:
            Only the deprecated `callback` and `callback_steps` entries are read; both emit a deprecation
            notice when they are not `None`.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """

    # Legacy callback API: still honoured, but each use triggers a deprecation notice.
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor
    # to deal with lora scaling and other possible forward hooks

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        callback_on_step_end_tensor_inputs,
    )

    # Stored on the instance so the read-only properties (`guidance_scale`, `clip_skip`, ...) and
    # `do_classifier_free_guidance` reflect the arguments of this call.
    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )

    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
        clip_skip=self.clip_skip,
    )

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    if ip_adapter_image is not None:
        # Hidden states are requested from the image encoder unless the UNet projects with `ImageProjection`.
        output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
        image_embeds, negative_image_embeds = self.encode_image(
            ip_adapter_image, device, num_images_per_prompt, output_hidden_state
        )
        if self.do_classifier_free_guidance:
            image_embeds = torch.cat([negative_image_embeds, image_embeds])

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 6.1 Add image embeds for IP-Adapter
    added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

    # 6.2 Optionally get Guidance Scale Embedding
    # Only used when the UNet config defines `time_cond_proj_dim`; in that case
    # `do_classifier_free_guidance` is False and the guidance scale is fed through this embedding.
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                # The requested tensors are looked up by name among this function's local variables,
                # so the local names used here must not be changed.
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # The callback may hand back replacement tensors; missing keys keep the current values.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
            0
        ]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Images flagged by the safety checker are not denormalized during post-processing.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL.Image import torch from packaging import version from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer from ....configuration_utils import FrozenDict from ....image_processor import PipelineImageInput, VaeImageProcessor from ....loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin, ) from ....models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ....models.lora import adjust_lora_scale_text_encoder from ....schedulers import KarrasDiffusionSchedulers from ....utils import ( PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ....utils.torch_utils import randn_tensor from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker from .modeling_roberta_series import RobertaSeriesModelWithTransformation from .pipeline_output import AltDiffusionPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import requests >>> import torch >>> from PIL import Image >>> from io import BytesIO >>> from diffusers import AltDiffusionImg2ImgPipeline >>> device = "cuda" >>> model_id_or_path = 
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Extract latents from an encoder output object.

    When `encoder_output` exposes `latent_dist`, `sample_mode="sample"` draws from it with `generator` and
    `sample_mode="argmax"` returns its mode. Otherwise (or for any other `sample_mode`) the `latents`
    attribute is returned when present.

    Raises:
        AttributeError: If none of the access paths above applies.
    """
    if hasattr(encoder_output, "latent_dist"):
        distribution = encoder_output.latent_dist
        if sample_mode == "sample":
            return distribution.sample(generator)
        if sample_mode == "argmax":
            return distribution.mode()

    # Reached when there is no distribution, or when `sample_mode` is neither "sample" nor "argmax".
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or the scheduler's
    default spacing for `num_inference_steps`. Any extra `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps; only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Device forwarded to `scheduler.set_timesteps`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's `timesteps` after the call and the number of inference steps
        (the length of the schedule when a custom schedule was supplied).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own spacing.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: first make sure `set_timesteps` accepts the corresponding keyword.
    accepted_params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.RobertaSeriesModelWithTransformation`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.XLMRobertaTokenizer`]): A `XLMRobertaTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: RobertaSeriesModelWithTransformation,
    tokenizer: XLMRobertaTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    image_encoder: Optional[CLIPVisionModelWithProjection] = None,
    requires_safety_checker: bool = True,
):
    """
    Validate the supplied components, patch outdated configs and register all modules.

    Outdated scheduler settings (`steps_offset != 1`, `clip_sample is True`) and an outdated UNet
    `sample_size` (< 64 for configs older than 0.9.0) trigger a deprecation notice and are overwritten on
    the component's `_internal_dict`.

    Raises:
        ValueError: If a `safety_checker` is passed without a `feature_extractor`.
    """
    super().__init__()

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        # The config is frozen, so a patched copy replaces the internal dict.
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # Fixed: the message needs the `f` prefix, otherwise `{self.__class__}` is emitted literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    # One factor of two per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when `do_classifier_free_guidance` is `True` and
            `negative_prompt_embeds` is not given.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`. Both are repeated `num_images_per_prompt` times;
        `negative_prompt_embeds` is returned exactly as passed in (possibly `None`) when
        `do_classifier_free_guidance` is `False`.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        ValueError: If a list `negative_prompt` does not match the batch size of `prompt`.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize a second time without truncation, only to detect and report dropped text.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # The attention mask is only passed when the text encoder config opts in.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            # NOTE(review): assumes the text encoder exposes `text_model.final_layer_norm` - confirm
            # this holds for `RobertaSeriesModelWithTransformation`.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Target dtype preference: text encoder, then UNet, then the embeddings' own dtype.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            # No negative prompt: use one empty string per batch entry.
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the negative prompt to the sequence length of the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker over a batch of generated images.

    Args:
        image: Generated images. Tensors are converted to PIL through
            `self.image_processor.postprocess`; anything else goes through
            `self.image_processor.numpy_to_pil`.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the extracted `pixel_values` are cast to before being handed to the checker.

    Returns:
        `tuple`: `(image, has_nsfw_concept)`. When `self.safety_checker` is `None` the input is returned
        untouched and `has_nsfw_concept` is `None`; otherwise both values come from the safety checker.
    """
    # Nothing to check against: hand the input straight back.
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images, so convert according to the input kind.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    checker_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    clip_input = checker_features.pixel_values.to(dtype)

    checked_image, nsfw_flags = self.safety_checker(images=image, clip_input=clip_input)
    return checked_image, nsfw_flags
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments accepted by `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so only the arguments the current scheduler declares
    are forwarded.

    Args:
        generator: Random generator(s) to forward when `step` has a `generator` parameter.
        eta: Corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between
            [0, 1]. Forwarded only when `step` has an `eta` parameter.

    Returns:
        `dict`: A dictionary holding `"eta"` and/or `"generator"`, restricted to the names `step` accepts.
    """
    # Inspect the signature once; the parameter mapping supports `in` directly.
    step_params = inspect.signature(self.scheduler.step).parameters

    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." 
def get_timesteps(self, num_inference_steps, strength, device):
    """
    Select the tail of the scheduler's timestep schedule that matches `strength`.

    Args:
        num_inference_steps: Number of steps the scheduler was configured with.
        strength: Fraction of the schedule to run; `int(num_inference_steps * strength)` steps are kept,
            capped at `num_inference_steps`.
        device: Accepted for interface compatibility; not read by this method.

    Returns:
        `tuple`: `(timesteps, steps)` where `timesteps` is `self.scheduler.timesteps` sliced from the first
        kept step (scaled by `self.scheduler.order`) and `steps` is the number of kept steps.
    """
    # How many steps of the schedule are actually run, never more than the full schedule.
    noised_steps = int(num_inference_steps * strength)
    if noised_steps > num_inference_steps:
        noised_steps = num_inference_steps

    # Index of the first step to run, never negative.
    skipped = num_inference_steps - noised_steps
    if skipped < 0:
        skipped = 0

    scheduler = self.scheduler
    remaining = scheduler.timesteps[skipped * scheduler.order :]
    return remaining, num_inference_steps - skipped
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
    """
    Build sinusoidal embeddings of the guidance scale `w`.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scale values to embed. Each value is multiplied by 1000 before the
            sinusoids are evaluated.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate. Odd values are zero-padded by one column. Note that
            values below 4 make `half_dim - 1` zero and therefore divide by zero.
        dtype:
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, on the same device as `w`.
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    # Build the frequency table on the device of `w` so the multiplication below
    # does not mix devices when `w` lives on an accelerator.
    log_step = torch.log(torch.tensor(10000.0, device=w.device)) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -log_step)

    emb = w.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: PipelineImageInput = None,
    strength: float = 0.8,
    num_inference_steps: Optional[int] = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: int = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
            numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
            list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
            a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
            latents as `image`, but if passing latents directly it is not encoded again.
        strength (`float`, *optional*, defaults to 0.8):
            Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
            starting point and more noise is added the higher the `strength`. The number of denoising steps
            depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the
            denoising process runs for the full number of iterations specified in `num_inference_steps`. A value
            of 1 essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference. This parameter is modulated by `strength`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps`
            argument in their `set_timesteps` method. If not defined, the default behavior when
            `num_inference_steps` is passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. They are forwarded, together with `timesteps`, to `retrieve_timesteps` when the
            scheduler is configured.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only
            applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting).
            If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. Passing
            `"latent"` skips VAE decoding and the safety checker.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its `"scale"` entry, when present, is also used as the text encoder LoRA scale.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means
            that the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is
            called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int,
            timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as
            specified by `callback_on_step_end_tensor_inputs`. The function must return a dict; its `latents`,
            `prompt_embeds` and `negative_prompt_embeds` entries, when present, replace the current values.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the
            list will be passed as `callback_kwargs` argument. You will only be able to include variables listed
            in the `._callback_tensor_inputs` attribute of your pipeline class.
        kwargs:
            Only the deprecated `callback` and `callback_steps` entries are read. When `callback` is given
            without `callback_steps`, it is invoked on every step.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """

    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )

    # A legacy `callback` passed without `callback_steps` would otherwise reach
    # `i % None` inside the denoising loop; fall back to calling it on every step.
    if callback is not None and callback_steps is None:
        callback_steps = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        strength,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        callback_on_step_end_tensor_inputs,
    )

    # These back the `guidance_scale`, `clip_skip` and `cross_attention_kwargs` properties read below.
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    if ip_adapter_image is not None:
        # Hidden states are requested unless the UNet projects image embeddings with `ImageProjection`.
        output_hidden_state = not isinstance(self.unet.encoder_hid_proj, ImageProjection)
        image_embeds, negative_image_embeds = self.encode_image(
            ip_adapter_image, device, num_images_per_prompt, output_hidden_state
        )
        if self.do_classifier_free_guidance:
            image_embeds = torch.cat([negative_image_embeds, image_embeds])

    # 4. Preprocess image
    image = self.image_processor.preprocess(image)

    # 5. set timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )
    # Keep only the part of the schedule selected by `strength`.
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

    # 6. Prepare latent variables
    latents = self.prepare_latents(
        image,
        latent_timestep,
        batch_size,
        num_images_per_prompt,
        prompt_embeds.dtype,
        device,
        generator,
    )

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7.1 Add image embeds for IP-Adapter
    added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

    # 7.2 Optionally get Guidance Scale Embedding
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                # Keep this an explicit loop: `locals()` must be evaluated in the
                # scope of `__call__`, not inside a comprehension's own scope.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
            0
        ]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Images flagged by the safety checker are not denormalized.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
nsfw_content_detected (`List[bool]`) List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or `None` if safety checking could not be performed. """ images: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: Optional[List[bool]] ================================================ FILE: src/diffusers/pipelines/deprecated/audio_diffusion/__init__.py ================================================ from typing import TYPE_CHECKING from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule _import_structure = { "mel": ["Mel"], "pipeline_audio_diffusion": ["AudioDiffusionPipeline"], } if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .mel import Mel from .pipeline_audio_diffusion import AudioDiffusionPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) ================================================ FILE: src/diffusers/pipelines/deprecated/audio_diffusion/mel.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class Mel(ConfigMixin, SchedulerMixin):
    """
    Convert between raw audio and grayscale mel-spectrogram images.

    Parameters:
        x_res (`int`):
            x resolution of spectrogram (time).
        y_res (`int`):
            y resolution of spectrogram (frequency bins).
        sample_rate (`int`):
            Sample rate of audio.
        n_fft (`int`):
            FFT window length, passed as `n_fft` to the librosa mel-spectrogram routines.
        hop_length (`int`):
            Hop length (a higher number is recommended if `y_res` < 256).
        top_db (`int`):
            Loudest decibel value.
        n_iter (`int`):
            Number of iterations for Griffin-Lim Mel inversion.
    """

    config_name = "mel_config.json"

    @register_to_config
    def __init__(
        self,
        x_res: int = 256,
        y_res: int = 256,
        sample_rate: int = 22050,
        n_fft: int = 2048,
        hop_length: int = 512,
        top_db: int = 80,
        n_iter: int = 32,
    ):
        # Fail fast: every conversion method below needs librosa.
        if not _librosa_can_be_imported:
            raise ValueError(_import_error)

        self.hop_length = hop_length
        self.sr = sample_rate
        self.n_fft = n_fft
        self.top_db = top_db
        self.n_iter = n_iter
        # Reads `self.hop_length`, so it must run after the assignment above.
        self.set_resolution(x_res, y_res)
        self.audio = None

    def set_resolution(self, x_res: int, y_res: int):
        """Set resolution.

        Args:
            x_res (`int`):
                x resolution of spectrogram (time).
            y_res (`int`):
                y resolution of spectrogram (frequency bins).
        """
        self.x_res = x_res
        self.y_res = y_res
        # One mel band per image row.
        self.n_mels = self.y_res
        # Number of audio samples consumed by one spectrogram slice.
        self.slice_size = self.x_res * self.hop_length - 1

    def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
        """Load audio.

        `audio_file` takes precedence when both arguments are given. Audio shorter than
        `x_res * hop_length` samples is padded with trailing zeros (silence).

        Args:
            audio_file (`str`):
                An audio file that must be on disk due to [Librosa](https://librosa.org/) limitation.
            raw_audio (`np.ndarray`):
                The raw audio file as a NumPy array.

        Raises:
            `ValueError`: If neither `audio_file` nor `raw_audio` is provided.
        """
        if audio_file is not None:
            self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr)
        elif raw_audio is not None:
            self.audio = raw_audio
        else:
            raise ValueError("Either `audio_file` or `raw_audio` must be provided.")

        # Pad with silence if necessary.
        min_length = self.x_res * self.hop_length
        if len(self.audio) < min_length:
            self.audio = np.concatenate([self.audio, np.zeros((min_length - len(self.audio),))])

    def get_number_of_slices(self) -> int:
        """Get number of slices in audio.

        Returns:
            `int`:
                Number of spectrograms audio can be sliced into.
        """
        return len(self.audio) // self.slice_size

    def get_audio_slice(self, slice: int = 0) -> np.ndarray:
        """Get slice of audio.

        Args:
            slice (`int`):
                Slice number of audio (out of `get_number_of_slices()`).

        Returns:
            `np.ndarray`:
                The audio slice as a NumPy array.
        """
        return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)]

    def get_sample_rate(self) -> int:
        """Get sample rate.

        Returns:
            `int`:
                Sample rate of audio.
        """
        return self.sr

    def audio_slice_to_image(self, slice: int) -> Image.Image:
        """Convert slice of audio to spectrogram.

        Args:
            slice (`int`):
                Slice number of audio to convert (out of `get_number_of_slices()`).

        Returns:
            `PIL Image`:
                A grayscale image of `x_res x y_res`.
        """
        S = librosa.feature.melspectrogram(
            y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels
        )
        log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
        # Map the shifted decibel values from [0, top_db] onto [0, 255]; the + 0.5
        # turns the truncating uint8 cast into rounding.
        bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8)
        image = Image.fromarray(bytedata)
        return image

    def image_to_audio(self, image: Image.Image) -> np.ndarray:
        """Converts spectrogram to audio.

        Args:
            image (`PIL Image`):
                A grayscale image of `x_res x y_res`.

        Returns:
            audio (`np.ndarray`):
                The audio as a NumPy array.
        """
        # NOTE(review): the reshape assumes one byte per pixel (single-channel 8-bit image) - confirm callers.
        bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width))
        # Inverse of the pixel scaling in `audio_slice_to_image`.
        log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
        S = librosa.db_to_power(log_S)
        audio = librosa.feature.inverse.mel_to_audio(
            S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter
        )
        return audio
unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. mel ([`Mel`]): Transform audio into a spectrogram. scheduler ([`DDIMScheduler`] or [`DDPMScheduler`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`] or [`DDPMScheduler`]. """ _optional_components = ["vqvae"] def __init__( self, vqvae: AutoencoderKL, unet: UNet2DConditionModel, mel: Mel, scheduler: Union[DDIMScheduler, DDPMScheduler], ): super().__init__() self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae) def get_default_steps(self) -> int: """Returns default number of steps recommended for inference. Returns: `int`: The number of steps. """ return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000 @torch.no_grad() def __call__( self, batch_size: int = 1, audio_file: str = None, raw_audio: np.ndarray = None, slice: int = 0, start_step: int = 0, steps: int = None, generator: torch.Generator = None, mask_start_secs: float = 0, mask_end_secs: float = 0, step_generator: torch.Generator = None, eta: float = 0, noise: torch.Tensor = None, encoding: torch.Tensor = None, return_dict=True, ) -> Union[ Union[AudioPipelineOutput, ImagePipelineOutput], Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]], ]: """ The call function to the pipeline for generation. Args: batch_size (`int`): Number of samples to generate. audio_file (`str`): An audio file that must be on disk due to [Librosa](https://librosa.org/) limitation. raw_audio (`np.ndarray`): The raw audio file as a NumPy array. slice (`int`): Slice number of audio to convert. start_step (int): Step to start diffusion from. steps (`int`): Number of denoising steps (defaults to `50` for DDIM and `1000` for DDPM). generator (`torch.Generator`): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. 
mask_start_secs (`float`): Number of seconds of audio to mask (not generate) at start. mask_end_secs (`float`): Number of seconds of audio to mask (not generate) at end. step_generator (`torch.Generator`): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) used to denoise. None eta (`float`): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. noise (`torch.Tensor`): A noise tensor of shape `(batch_size, 1, height, width)` or `None`. encoding (`torch.Tensor`): A tensor for [`UNet2DConditionModel`] of shape `(batch_size, seq_length, cross_attention_dim)`. return_dict (`bool`): Whether or not to return a [`AudioPipelineOutput`], [`ImagePipelineOutput`] or a plain tuple. Examples: For audio diffusion: ```py import torch from IPython.display import Audio from diffusers import DiffusionPipeline device = "cuda" if torch.cuda.is_available() else "cpu" pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-256").to(device) output = pipe() display(output.images[0]) display(Audio(output.audios[0], rate=mel.get_sample_rate())) ``` For latent audio diffusion: ```py import torch from IPython.display import Audio from diffusers import DiffusionPipeline device = "cuda" if torch.cuda.is_available() else "cpu" pipe = DiffusionPipeline.from_pretrained("teticio/latent-audio-diffusion-256").to(device) output = pipe() display(output.images[0]) display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate())) ``` For other tasks like variation, inpainting, outpainting, etc: ```py output = pipe( raw_audio=output.audios[0, 0], start_step=int(pipe.get_default_steps() / 2), mask_start_secs=1, mask_end_secs=1, ) display(output.images[0]) display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate())) ``` Returns: `List[PIL Image]`: A list of Mel spectrograms (`float`, `List[np.ndarray]`) with the sample rate and raw 
audio. """ steps = steps or self.get_default_steps() self.scheduler.set_timesteps(steps) step_generator = step_generator or generator # For backwards compatibility if isinstance(self.unet.config.sample_size, int): self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size) if noise is None: noise = randn_tensor( ( batch_size, self.unet.config.in_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1], ), generator=generator, device=self.device, ) images = noise mask = None if audio_file is not None or raw_audio is not None: self.mel.load_audio(audio_file, raw_audio) input_image = self.mel.audio_slice_to_image(slice) input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape( (input_image.height, input_image.width) ) input_image = (input_image / 255) * 2 - 1 input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device) if self.vqvae is not None: input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample( generator=generator )[0] input_images = self.vqvae.config.scaling_factor * input_images if start_step > 0: images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1]) pixels_per_second = ( self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length ) mask_start = int(mask_start_secs * pixels_per_second) mask_end = int(mask_end_secs * pixels_per_second) mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:])) for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])): if isinstance(self.unet, UNet2DConditionModel): model_output = self.unet(images, t, encoding)["sample"] else: model_output = self.unet(images, t)["sample"] if isinstance(self.scheduler, DDIMScheduler): images = self.scheduler.step( model_output=model_output, timestep=t, sample=images, eta=eta, generator=step_generator, 
)["prev_sample"] else: images = self.scheduler.step( model_output=model_output, timestep=t, sample=images, generator=step_generator, )["prev_sample"] if mask is not None: if mask_start > 0: images[:, :, :, :mask_start] = mask[:, step, :, :mask_start] if mask_end > 0: images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:] if self.vqvae is not None: # 0.18215 was scaling factor used in training to ensure unit variance images = 1 / self.vqvae.config.scaling_factor * images images = self.vqvae.decode(images)["sample"] images = (images / 2 + 0.5).clamp(0, 1) images = images.cpu().permute(0, 2, 3, 1).numpy() images = (images * 255).round().astype("uint8") images = list( (Image.fromarray(_[:, :, 0]) for _ in images) if images.shape[3] == 1 else (Image.fromarray(_, mode="RGB").convert("L") for _ in images) ) audios = [self.mel.image_to_audio(_) for _ in images] if not return_dict: return images, (self.mel.get_sample_rate(), audios) return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images)) @torch.no_grad() def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray: """ Reverse the denoising step process to recover a noisy image from the generated image. Args: images (`List[PIL Image]`): List of images to encode. steps (`int`): Number of encoding steps to perform (defaults to `50`). Returns: `np.ndarray`: A noise tensor of shape `(batch_size, 1, height, width)`. 
""" # Only works with DDIM as this method is deterministic assert isinstance(self.scheduler, DDIMScheduler) self.scheduler.set_timesteps(steps) sample = np.array( [np.frombuffer(image.tobytes(), dtype="uint8").reshape((1, image.height, image.width)) for image in images] ) sample = (sample / 255) * 2 - 1 sample = torch.Tensor(sample).to(self.device) for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))): prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps alpha_prod_t = self.scheduler.alphas_cumprod[t] alpha_prod_t_prev = ( self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod ) beta_prod_t = 1 - alpha_prod_t model_output = self.unet(sample, t)["sample"] pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output sample = (sample - pred_sample_direction) * alpha_prod_t_prev ** (-0.5) sample = sample * alpha_prod_t ** (0.5) + beta_prod_t ** (0.5) * model_output return sample @staticmethod def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.Tensor: """Spherical Linear intERPolation. Args: x0 (`torch.Tensor`): The first tensor to interpolate between. x1 (`torch.Tensor`): Second tensor to interpolate between. alpha (`float`): Interpolation between 0 and 1 Returns: `torch.Tensor`: The interpolated tensor. 
""" theta = acos(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1)) return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(alpha * theta) * x1 / sin(theta) ================================================ FILE: src/diffusers/pipelines/deprecated/latent_diffusion_uncond/__init__.py ================================================ from typing import TYPE_CHECKING from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule _import_structure = {"pipeline_latent_diffusion_uncond": ["LDMPipeline"]} if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .pipeline_latent_diffusion_uncond import LDMPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) ================================================ FILE: src/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import List, Optional, Tuple, Union import torch from ....models import UNet2DModel, VQModel from ....schedulers import DDIMScheduler from ....utils.torch_utils import randn_tensor from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput class LDMPipeline(DiffusionPipeline): r""" Pipeline for unconditional image generation using latent diffusion. This model inherits from [`DiffusionPipeline`]. 
@torch.no_grad()
def __call__(
    self,
    batch_size: int = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    eta: float = 0.0,
    num_inference_steps: int = 50,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    **kwargs,
) -> Union[Tuple, ImagePipelineOutput]:
    r"""
    The call function to the pipeline for generation.

    Args:
        batch_size (`int`, *optional*, defaults to 1):
            Number of images to generate.
        generator (`torch.Generator`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        eta (`float`, *optional*, defaults to 0.0):
            Forwarded to `scheduler.step` as `eta`, but only when that method's signature has an `eta` parameter.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        **kwargs:
            Accepted but not used by this method.

    Example:

    ```py
    >>> from diffusers import LDMPipeline

    >>> # load model and scheduler
    >>> pipe = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")

    >>> # run pipeline in inference (sample random noise and denoise)
    >>> image = pipe().images[0]
    ```

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
            returned where the first element is a list with the generated images
    """

    unet_config = self.unet.config
    latent_shape = (batch_size, unet_config.in_channels, unet_config.sample_size, unet_config.sample_size)
    # the noise is drawn first and only afterwards moved to the pipeline device
    latents = randn_tensor(latent_shape, generator=generator).to(self.device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma

    self.scheduler.set_timesteps(num_inference_steps)

    # not every scheduler's `step` takes `eta`, so pass it only when the signature declares it
    step_kwargs = {"eta": eta} if "eta" in inspect.signature(self.scheduler.step).parameters else {}

    for t in self.progress_bar(self.scheduler.timesteps):
        model_input = self.scheduler.scale_model_input(latents, t)
        # predict the noise residual
        noise_pred = self.unet(model_input, t).sample
        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **step_kwargs).prev_sample

    # undo the VAE scaling, then decode the latents into image space
    latents = latents / self.vqvae.config.scaling_factor
    decoded = self.vqvae.decode(latents).sample

    # map from [-1, 1] to [0, 1] and convert to a channels-last NumPy array
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    decoded = decoded.cpu().permute(0, 2, 3, 1).numpy()
    if output_type == "pil":
        decoded = self.numpy_to_pil(decoded)

    if return_dict:
        return ImagePipelineOutput(images=decoded)
    return (decoded,)
class PNDMPipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional image generation.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image latents.
        scheduler ([`PNDMScheduler`]):
            A `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image. The registered
            scheduler is a new `PNDMScheduler` built from the config of the one passed in.
    """

    unet: UNet2DModel
    scheduler: PNDMScheduler

    def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
        super().__init__()

        # rebuild the scheduler as a PNDMScheduler from the supplied scheduler's config
        pndm_scheduler = PNDMScheduler.from_config(scheduler.config)

        self.register_modules(unet=unet, scheduler=pndm_scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 50,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, `optional`, defaults to 1):
                The number of images to generate.
            num_inference_steps (`int`, `optional`, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator`, `optional`):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            output_type (`str`, `optional`, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.
            **kwargs:
                Accepted but not used by this method.

        Example:

        ```py
        >>> from diffusers import PNDMPipeline

        >>> # load model and scheduler
        >>> pndm = PNDMPipeline.from_pretrained("google/ddpm-cifar10-32")

        >>> # run pipeline in inference (sample random noise and denoise)
        >>> image = pndm().images[0]

        >>> # save image
        >>> image.save("pndm_generated_image.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """
        # For more information on the sampling method you can take a look at Algorithm 2 of
        # the official paper: https://arxiv.org/pdf/2202.09778.pdf

        # Sample gaussian noise to begin loop
        unet_config = self.unet.config
        noise_shape = (batch_size, unet_config.in_channels, unet_config.sample_size, unet_config.sample_size)
        sample = randn_tensor(noise_shape, generator=generator, device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)
        for t in self.progress_bar(self.scheduler.timesteps):
            prediction = self.unet(sample, t).sample
            sample = self.scheduler.step(prediction, t, sample).prev_sample

        # map from [-1, 1] to [0, 1] and convert to a channels-last NumPy array
        sample = (sample / 2 + 0.5).clamp(0, 1)
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            sample = self.numpy_to_pil(sample)

        if return_dict:
            return ImagePipelineOutput(images=sample)
        return (sample,)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Tuple, Union import numpy as np import PIL.Image import torch from ....models import UNet2DModel from ....schedulers import RePaintScheduler from ....utils import PIL_INTERPOLATION, deprecate, logging from ....utils.torch_utils import randn_tensor from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]): deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) 
instead" deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): image = [image] if isinstance(image[0], PIL.Image.Image): w, h = image[0].size w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) image = 2.0 * image - 1.0 image = torch.from_numpy(image) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, dim=0) return image def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]): if isinstance(mask, torch.Tensor): return mask elif isinstance(mask, PIL.Image.Image): mask = [mask] if isinstance(mask[0], PIL.Image.Image): w, h = mask[0].size w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask] mask = np.concatenate(mask, axis=0) mask = mask.astype(np.float32) / 255.0 mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 mask = torch.from_numpy(mask) elif isinstance(mask[0], torch.Tensor): mask = torch.cat(mask, dim=0) return mask class RePaintPipeline(DiffusionPipeline): r""" Pipeline for image inpainting using RePaint. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Parameters: unet ([`UNet2DModel`]): A `UNet2DModel` to denoise the encoded image latents. scheduler ([`RePaintScheduler`]): A `RePaintScheduler` to be used in combination with `unet` to denoise the encoded image. 
""" unet: UNet2DModel scheduler: RePaintScheduler model_cpu_offload_seq = "unet" def __init__(self, unet, scheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() def __call__( self, image: Union[torch.Tensor, PIL.Image.Image], mask_image: Union[torch.Tensor, PIL.Image.Image], num_inference_steps: int = 250, eta: float = 0.0, jump_length: int = 10, jump_n_sample: int = 10, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, ) -> Union[ImagePipelineOutput, Tuple]: r""" The call function to the pipeline for generation. Args: image (`torch.Tensor` or `PIL.Image.Image`): The original image to inpaint on. mask_image (`torch.Tensor` or `PIL.Image.Image`): The mask_image where 0.0 define which part of the original image to inpaint. num_inference_steps (`int`, *optional*, defaults to 1000): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. eta (`float`): The weight of the added noise in a diffusion step. Its value is between 0.0 and 1.0; 0.0 corresponds to DDIM and 1.0 is the DDPM scheduler. jump_length (`int`, *optional*, defaults to 10): The number of steps taken forward in time before going backward in time for a single jump ("j" in RePaint paper). Take a look at Figure 9 and 10 in the [paper](https://arxiv.org/pdf/2201.09865.pdf). jump_n_sample (`int`, *optional*, defaults to 10): The number of times to make a forward time jump for a given chosen time sample. Take a look at Figure 9 and 10 in the [paper](https://arxiv.org/pdf/2201.09865.pdf). generator (`torch.Generator`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. Example: ```py >>> from io import BytesIO >>> import torch >>> import PIL >>> import requests >>> from diffusers import RePaintPipeline, RePaintScheduler >>> def download_image(url): ... response = requests.get(url) ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") >>> img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png" >>> mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png" >>> # Load the original image and the mask as PIL images >>> original_image = download_image(img_url).resize((256, 256)) >>> mask_image = download_image(mask_url).resize((256, 256)) >>> # Load the RePaint scheduler and pipeline based on a pretrained DDPM model >>> scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256") >>> pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler) >>> pipe = pipe.to("cuda") >>> generator = torch.Generator(device="cuda").manual_seed(0) >>> output = pipe( ... image=original_image, ... mask_image=mask_image, ... num_inference_steps=250, ... eta=0.0, ... jump_length=10, ... jump_n_sample=10, ... generator=generator, ... ) >>> inpainted_image = output.images[0] ``` Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images. 
""" original_image = image original_image = _preprocess_image(original_image) original_image = original_image.to(device=self._execution_device, dtype=self.unet.dtype) mask_image = _preprocess_mask(mask_image) mask_image = mask_image.to(device=self._execution_device, dtype=self.unet.dtype) batch_size = original_image.shape[0] # sample gaussian noise to begin the loop if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) image_shape = original_image.shape image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) # set step values self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self._execution_device) self.scheduler.eta = eta t_last = self.scheduler.timesteps[0] + 1 generator = generator[0] if isinstance(generator, list) else generator for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): if t < t_last: # predict the noise residual model_output = self.unet(image, t).sample # compute previous image: x_t -> x_t-1 image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample else: # compute the reverse: x_t-1 -> x_t image = self.scheduler.undo_step(image, t_last, generator) t_last = t image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/deprecated/score_sde_ve/__init__.py ================================================ from typing import TYPE_CHECKING from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule _import_structure = {"pipeline_score_sde_ve": ["ScoreSdeVePipeline"]} 
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .pipeline_score_sde_ve import ScoreSdeVePipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) ================================================ FILE: src/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Tuple, Union import torch from ....models import UNet2DModel from ....schedulers import ScoreSdeVeScheduler from ....utils.torch_utils import randn_tensor from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput class ScoreSdeVePipeline(DiffusionPipeline): r""" Pipeline for unconditional image generation. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Parameters: unet ([`UNet2DModel`]): A `UNet2DModel` to denoise the encoded image. scheduler ([`ScoreSdeVeScheduler`]): A `ScoreSdeVeScheduler` to be used in combination with `unet` to denoise the encoded image. 
""" unet: UNet2DModel scheduler: ScoreSdeVeScheduler def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() def __call__( self, batch_size: int = 1, num_inference_steps: int = 2000, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, **kwargs, ) -> Union[ImagePipelineOutput, Tuple]: r""" The call function to the pipeline for generation. Args: batch_size (`int`, *optional*, defaults to 1): The number of images to generate. generator (`torch.Generator`, `optional`): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images. 
""" img_size = self.unet.config.sample_size shape = (batch_size, 3, img_size, img_size) model = self.unet sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma sample = sample.to(self.device) self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_sigmas(num_inference_steps) for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) # correction step for _ in range(self.scheduler.config.correct_steps): model_output = self.unet(sample, sigma_t).sample sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample # prediction step model_output = model(sample, sigma_t).sample output = self.scheduler.step_pred(model_output, t, sample, generator=generator) sample, sample_mean = output.prev_sample, output.prev_sample_mean sample = sample_mean.clamp(0, 1) sample = sample.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": sample = self.numpy_to_pil(sample) if not return_dict: return (sample,) return ImagePipelineOutput(images=sample) ================================================ FILE: src/diffusers/pipelines/deprecated/spectrogram_diffusion/__init__.py ================================================ # flake8: noqa from typing import TYPE_CHECKING from ....utils import ( DIFFUSERS_SLOW_IMPORT, _LazyModule, is_note_seq_available, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available, get_objects_from_module, ) _dummy_objects = {} _import_structure = {} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ....utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["continous_encoder"] = ["SpectrogramContEncoder"] _import_structure["notes_encoder"] = 
["SpectrogramNotesEncoder"] _import_structure["pipeline_spectrogram_diffusion"] = [ "SpectrogramContEncoder", "SpectrogramDiffusionPipeline", "T5FilmDecoder", ] try: if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ....utils import dummy_transformers_and_torch_and_note_seq_objects _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) else: _import_structure["midi_utils"] = ["MidiProcessor"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ....utils.dummy_torch_and_transformers_objects import * else: from .pipeline_spectrogram_diffusion import SpectrogramDiffusionPipeline from .pipeline_spectrogram_diffusion import SpectrogramContEncoder from .pipeline_spectrogram_diffusion import SpectrogramNotesEncoder from .pipeline_spectrogram_diffusion import T5FilmDecoder try: if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ....utils.dummy_transformers_and_torch_and_note_seq_objects import * else: from .midi_utils import MidiProcessor else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py ================================================ # Copyright 2022 The Music Spectrogram Diffusion Authors. # Copyright 2024 The HuggingFace Team. All rights reserved. 
class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    """T5-block encoder over continuous input frames.

    Inputs are linearly projected to `d_model`, summed with a frozen learned position embedding, and passed
    through `num_layers` `T5Block`s followed by a `T5LayerNorm` and dropout.
    """

    @register_to_config
    def __init__(
        self,
        input_dims: int,
        targets_context_length: int,
        d_model: int,
        dropout_rate: float,
        num_layers: int,
        num_heads: int,
        d_kv: int,
        d_ff: int,
        feed_forward_proj: str,
        is_decoder: bool = False,
    ):
        super().__init__()

        self.input_proj = nn.Linear(input_dims, d_model, bias=False)

        # Position table is part of the state dict but is never trained.
        self.position_encoding = nn.Embedding(targets_context_length, d_model)
        self.position_encoding.weight.requires_grad = False

        self.dropout_pre = nn.Dropout(p=dropout_rate)

        block_config = T5Config(
            d_model=d_model,
            num_heads=num_heads,
            d_kv=d_kv,
            d_ff=d_ff,
            feed_forward_proj=feed_forward_proj,
            dropout_rate=dropout_rate,
            is_decoder=is_decoder,
            is_encoder_decoder=False,
        )
        self.encoders = nn.ModuleList([T5Block(block_config) for _ in range(num_layers)])

        self.layer_norm = T5LayerNorm(d_model)
        self.dropout_post = nn.Dropout(p=dropout_rate)

    def forward(self, encoder_inputs, encoder_inputs_mask):
        """Encode `encoder_inputs`; returns `(encoded, encoder_inputs_mask)` with the mask passed through."""
        hidden = self.input_proj(encoder_inputs)

        # terminal relative positional encodings
        num_positions = encoder_inputs.shape[1]
        positions = torch.arange(num_positions, device=encoder_inputs.device)
        valid_lengths = encoder_inputs_mask.sum(-1)
        positions = torch.roll(positions.unsqueeze(0), tuple(valid_lengths.tolist()), dims=0)
        hidden += self.position_encoding(positions)

        hidden = self.dropout_pre(hidden)

        # inverted the attention mask
        attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, encoder_inputs.size())

        for block in self.encoders:
            hidden = block(hidden, attention_mask)[0]

        hidden = self.dropout_post(self.layer_norm(hidden))
        return hidden, encoder_inputs_mask
@dataclasses.dataclass
class NoteEventData:
    """Payload of a single note onset or offset before it is converted to codec events."""

    # Pitch of the note; emitted as a "pitch" event (or a "drum" event for drum notes).
    pitch: int
    # Note velocity. 0 is used for note offsets; None means "onsets only", in which case
    # only a pitch event is emitted (no velocity or program event).
    velocity: Optional[int] = None
    # Program number; None means no "program" event is emitted for this note.
    program: Optional[int] = None
    # When truthy (and `program` is set), the note is emitted through the separate "drum" vocabulary.
    is_drum: Optional[bool] = None
    # Not read by the code visible in this module section — presumably an instrument id; TODO confirm.
    instrument: Optional[int] = None
class Codec:
    """Map typed events to flat vocabulary indices and back.

    The vocabulary is the concatenation of one contiguous index block per event type, in the order the ranges
    were declared. A "shift" block is always created first so shift events start at index 0; the remaining
    blocks come from `event_ranges`. This class is deliberately lightweight: it knows nothing about special
    tokens such as EOS or UNK.
    """

    def __init__(self, max_shift_steps: int, steps_per_second: float, event_ranges: List[EventRange]):
        """Define Codec.

        Args:
          max_shift_steps: Maximum number of shift steps that can be encoded.
          steps_per_second: Shift steps will be interpreted as having a duration of
              1 / steps_per_second.
          event_ranges: Other supported event types and their ranges.
        """
        self.steps_per_second = steps_per_second
        self._shift_range = EventRange(type="shift", min_value=0, max_value=max_shift_steps)
        self._event_ranges = [self._shift_range] + event_ranges
        # Ensure all event types have unique names.
        distinct_types = {event_range.type for event_range in self._event_ranges}
        assert len(self._event_ranges) == len(distinct_types)

    def _indexed_ranges(self):
        """Yield `(first_index, last_index, event_range)` for every range, in vocabulary order."""
        first = 0
        for event_range in self._event_ranges:
            last = first + (event_range.max_value - event_range.min_value)
            yield first, last, event_range
            first = last + 1

    @property
    def num_classes(self) -> int:
        """Total vocabulary size over all event ranges (bounds are inclusive)."""
        total = 0
        for event_range in self._event_ranges:
            total += event_range.max_value - event_range.min_value + 1
        return total

    # The next couple methods are simplified special case methods just for shift
    # events that are intended to be used from within autograph functions.

    def is_shift_event_index(self, index: int) -> bool:
        """Tell whether `index` lies inside the shift block."""
        shift = self._shift_range
        return shift.min_value <= index <= shift.max_value

    @property
    def max_shift_steps(self) -> int:
        """Largest shift value a single shift event can carry."""
        return self._shift_range.max_value

    def encode_event(self, event: Event) -> int:
        """Encode an event to an index."""
        for first, _, event_range in self._indexed_ranges():
            if event.type != event_range.type:
                continue
            if not event_range.min_value <= event.value <= event_range.max_value:
                raise ValueError(
                    f"Event value {event.value} is not within valid range "
                    f"[{event_range.min_value}, {event_range.max_value}] for type {event.type}"
                )
            return first + event.value - event_range.min_value

        raise ValueError(f"Unknown event type: {event.type}")

    def event_type_range(self, event_type: str) -> Tuple[int, int]:
        """Return [min_id, max_id] for an event type."""
        for first, last, event_range in self._indexed_ranges():
            if event_range.type == event_type:
                return first, last

        raise ValueError(f"Unknown event type: {event_type}")

    def decode_event_index(self, index: int) -> Event:
        """Decode an event index to an Event."""
        for first, last, event_range in self._indexed_ranges():
            if first <= index <= last:
                return Event(type=event_range.type, value=event_range.min_value + index - first)

        raise ValueError(f"Unknown event index: {index}")
def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1):
    """
    equivalent of tf.signal.frame

    Slides a window of `frame_length` samples over dimension `axis` of `signal`, advancing `frame_step`
    samples each time, using `Tensor.unfold` (the window contents end up in a new trailing dimension).

    Args:
        signal: `torch.Tensor` to split into frames.
        frame_length: Number of samples per frame.
        frame_step: Number of samples between the starts of consecutive frames.
        pad_end: If `True`, pad the end of `axis` with `pad_value` so that a trailing partial frame is kept,
            giving `ceil(signal_length / frame_step)` frames. If `False`, trailing samples that do not fill
            a whole frame are dropped.
        pad_value: Fill value used when `pad_end` is `True`.
        axis: Dimension to frame along (negative values count from the end).

    Returns:
        The framed tensor.
    """
    if pad_end:
        signal_length = signal.shape[axis]
        # Ceil division: every frame start that falls inside the signal produces a frame.
        num_frames = -(-signal_length // frame_step)
        # Pad only up to the end of the last frame. Padding a full `frame_length` when the signal
        # already divides evenly would append a spurious frame made entirely of `pad_value`.
        pad_size = max(0, frame_length + frame_step * (num_frames - 1) - signal_length)
        if pad_size > 0:
            dim = axis % signal.ndim
            # F.pad consumes (left, right) pairs starting from the LAST dimension, so emit a
            # zero pair for each dimension after `dim`, then pad `dim` on the right.
            pad_spec = [0, 0] * (signal.ndim - 1 - dim) + [0, pad_size]
            signal = F.pad(signal, pad_spec, "constant", pad_value)
    frames = signal.unfold(axis, frame_length, frame_step)
    return frames
def segment(a, n):
    """Split `a` into consecutive slices of `n` items; the final slice may be shorter."""
    pieces = []
    for start in range(0, len(a), n):
        stop = start + n
        pieces.append(a[start:stop])
    return pieces
def encode_and_index_events(
    state, event_times, event_values, codec, frame_times, encode_event_fn, encoding_state_to_events_fn=None
):
    """Encode a sequence of timed events and index to audio frame times.

    Encodes time shifts as repeated single step shifts for later run length encoding.

    Optionally, also encodes a sequence of "state events", keeping track of the current encoding state at each audio
    frame. This can be used e.g. to prepend events representing the current state to a targets segment.

    Args:
      state: Initial event encoding state. It is passed to `encode_event_fn` and
          `encoding_state_to_events_fn` for every event.
      event_times: Sequence of event times.
      event_values: Sequence of event values.
      codec: A Codec object that maps Event objects to indices.
      frame_times: Time for every audio frame.
      encode_event_fn: Function called as `encode_event_fn(state, event_value, codec)`
          that transforms an event value into a sequence of one or more Event objects.
      encoding_state_to_events_fn: Optional function called as
          `encoding_state_to_events_fn(state)` that transforms the encoding state into
          a sequence of one or more Event objects.

    Returns:
      A list of dicts, one per run of `TARGET_FEATURE_LENGTH` audio frames (the last
      run may be shorter). Every dict has the keys:
        "inputs": All encoded events and shifts (the same full int32 array in every dict).
        "event_start_indices": Start event index for every audio frame of this run.
            Note: one event can correspond to multiple audio indices due to sampling
            rate differences. This makes splitting sequences tricky because the same
            event can appear at the end of one sequence and the beginning of another.
        "event_end_indices": End event index for every audio frame of this run. Used
            to ensure when slicing that one chunk ends where the next begins; built so
            that event_end_indices[i] = event_start_indices[i + 1].
        "state_events": All encoded "state" events, representing the encoding state
            before each event (the same full int32 array in every dict).
        "state_event_indices": State event index for every audio frame of this run.
    """
    # Stable sort keeps the caller's ordering for events that share a timestamp.
    indices = np.argsort(event_times, kind="stable")
    # Quantize event times to integer codec steps.
    event_steps = [round(event_times[i] * codec.steps_per_second) for i in indices]
    event_values = [event_values[i] for i in indices]

    events = []
    state_events = []
    event_start_indices = []
    state_event_indices = []

    cur_step = 0
    cur_event_idx = 0
    cur_state_event_idx = 0

    def fill_event_start_indices_to_cur_step():
        # Reads cur_step / cur_event_idx / cur_state_event_idx from the enclosing scope and
        # assigns the current indices to every not-yet-indexed frame that starts before cur_step.
        while (
            len(event_start_indices) < len(frame_times)
            and frame_times[len(event_start_indices)] < cur_step / codec.steps_per_second
        ):
            event_start_indices.append(cur_event_idx)
            state_event_indices.append(cur_state_event_idx)

    for event_step, event_value in zip(event_steps, event_values):
        # Emit one single-step shift per elapsed step until we reach this event's step.
        while event_step > cur_step:
            events.append(codec.encode_event(Event(type="shift", value=1)))
            cur_step += 1
            fill_event_start_indices_to_cur_step()
            cur_event_idx = len(events)
            cur_state_event_idx = len(state_events)
        if encoding_state_to_events_fn:
            # Dump state to state events *before* processing the next event, because
            # we want to capture the state prior to the occurrence of the event.
            for e in encoding_state_to_events_fn(state):
                state_events.append(codec.encode_event(e))

        for e in encode_event_fn(state, event_value, codec):
            events.append(codec.encode_event(e))

    # After the last event, continue filling out the event_start_indices array.
    # The inequality is not strict because if our current step lines up exactly
    # with (the start of) an audio frame, we need to add an additional shift event
    # to "cover" that frame.
    while cur_step / codec.steps_per_second <= frame_times[-1]:
        events.append(codec.encode_event(Event(type="shift", value=1)))
        cur_step += 1
        fill_event_start_indices_to_cur_step()
        cur_event_idx = len(events)

    # Now fill in event_end_indices. We need this extra array to make sure that
    # when we slice events, each slice ends exactly where the subsequent slice
    # begins.
    event_end_indices = event_start_indices[1:] + [len(events)]

    events = np.array(events).astype(np.int32)
    state_events = np.array(state_events).astype(np.int32)
    # Split the per-frame index arrays into runs of TARGET_FEATURE_LENGTH frames.
    event_start_indices = segment(np.array(event_start_indices).astype(np.int32), TARGET_FEATURE_LENGTH)
    event_end_indices = segment(np.array(event_end_indices).astype(np.int32), TARGET_FEATURE_LENGTH)
    state_event_indices = segment(np.array(state_event_indices).astype(np.int32), TARGET_FEATURE_LENGTH)

    outputs = []
    for start_indices, end_indices, event_indices in zip(event_start_indices, event_end_indices, state_event_indices):
        outputs.append(
            {
                "inputs": events,
                "event_start_indices": start_indices,
                "event_end_indices": end_indices,
                "state_events": state_events,
                "state_event_indices": event_indices,
            }
        )

    return outputs
def run_length_encode_shifts_fn(
    features,
    codec: Codec,
    feature_key: str = "inputs",
    state_change_event_types: Sequence[str] = (),
) -> Mapping[str, Any]:
    """Run-length encode the single-step shifts in `features[feature_key]`.

    Despite the `_fn` suffix, this does not return a function: it builds the encoder
    below and immediately applies it to `features`.

    Args:
      features: Dict of features to process. It is modified in place.
      codec: The Codec to use for shift events.
      feature_key: The feature key for which to run-length encode shifts.
      state_change_event_types: A list of event types that represent state changes; tokens corresponding to these event
        types will be interpreted as state changes and redundant ones will be removed.

    Returns:
      The same `features` mapping, with `features[feature_key]` replaced by the
      run-length encoded token array.
    """
    state_change_event_ranges = [codec.event_type_range(event_type) for event_type in state_change_event_types]

    def run_length_encode_shifts(features: MutableMapping[str, Any]) -> Mapping[str, Any]:
        """Combine leading/interior shifts, trim trailing shifts.

        Args:
          features: Dict of features to process.

        Returns:
          A dict of features.
        """
        events = features[feature_key]

        # shift_steps: shifts seen since the last emitted shift run (used as a "pending" flag).
        # total_shift_steps: shifts seen since the start of the sequence; never reset.
        shift_steps = 0
        total_shift_steps = 0
        output = np.array([], dtype=np.int32)

        # Last token seen for each state-change event type (initially 0).
        current_state = np.zeros(len(state_change_event_ranges), dtype=np.int32)

        for event in events:
            if codec.is_shift_event_index(event):
                shift_steps += 1
                total_shift_steps += 1

            else:
                # If this event is a state change and has the same value as the current
                # state, we can skip it entirely.
                is_redundant = False
                for i, (min_index, max_index) in enumerate(state_change_event_ranges):
                    if (min_index <= event) and (event <= max_index):
                        if current_state[i] == event:
                            is_redundant = True
                        current_state[i] = event
                if is_redundant:
                    continue

                # Once we've reached a non-shift event, RLE all previous shift events
                # before outputting the non-shift event.
                if shift_steps > 0:
                    # The emitted amount is the cumulative shift since the start of the
                    # sequence, not just the shift since the previous event.
                    shift_steps = total_shift_steps
                    while shift_steps > 0:
                        # A single shift token carries at most codec.max_shift_steps steps.
                        output_steps = np.minimum(codec.max_shift_steps, shift_steps)
                        output = np.concatenate([output, [output_steps]], axis=0)
                        shift_steps -= output_steps
                output = np.concatenate([output, [event]], axis=0)

        # Shifts after the last non-shift event are never written to `output`.
        features[feature_key] = output
        return features

    return run_length_encode_shifts(features)
EventRange("program", note_seq.MIN_MIDI_PROGRAM, note_seq.MAX_MIDI_PROGRAM), EventRange("drum", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), ], ) self.tokenizer = Tokenizer(self.codec.num_classes) self.note_representation_config = NoteRepresentationConfig(onsets_only=False, include_ties=True) def __call__(self, midi: Union[bytes, os.PathLike, str]): if not isinstance(midi, bytes): with open(midi, "rb") as f: midi = f.read() ns = note_seq.midi_to_note_sequence(midi) ns_sus = note_seq.apply_sustain_control_changes(ns) for note in ns_sus.notes: if not note.is_drum: note.program = program_to_slakh_program(note.program) samples = np.zeros(int(ns_sus.total_time * SAMPLE_RATE)) _, frame_times = audio_to_frames(samples, HOP_SIZE, FRAME_RATE) times, values = note_sequence_to_onsets_and_offsets_and_programs(ns_sus) events = encode_and_index_events( state=NoteEncodingState(), event_times=times, event_values=values, frame_times=frame_times, codec=self.codec, encode_event_fn=note_event_data_to_events, encoding_state_to_events_fn=note_encoding_state_to_events, ) events = [ note_representation_processor_chain(event, self.codec, self.note_representation_config) for event in events ] input_tokens = [self.tokenizer.encode(event["inputs"]) for event in events] return input_tokens ================================================ FILE: src/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py ================================================ # Copyright 2022 The Music Spectrogram Diffusion Authors. # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import torch.nn as nn from transformers.modeling_utils import ModuleUtilsMixin from transformers.models.t5.modeling_t5 import T5Block, T5Config, T5LayerNorm from ....configuration_utils import ConfigMixin, register_to_config from ....models import ModelMixin class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): @register_to_config def __init__( self, max_length: int, vocab_size: int, d_model: int, dropout_rate: float, num_layers: int, num_heads: int, d_kv: int, d_ff: int, feed_forward_proj: str, is_decoder: bool = False, ): super().__init__() self.token_embedder = nn.Embedding(vocab_size, d_model) self.position_encoding = nn.Embedding(max_length, d_model) self.position_encoding.weight.requires_grad = False self.dropout_pre = nn.Dropout(p=dropout_rate) t5config = T5Config( vocab_size=vocab_size, d_model=d_model, num_heads=num_heads, d_kv=d_kv, d_ff=d_ff, dropout_rate=dropout_rate, feed_forward_proj=feed_forward_proj, is_decoder=is_decoder, is_encoder_decoder=False, ) self.encoders = nn.ModuleList() for lyr_num in range(num_layers): lyr = T5Block(t5config) self.encoders.append(lyr) self.layer_norm = T5LayerNorm(d_model) self.dropout_post = nn.Dropout(p=dropout_rate) def forward(self, encoder_input_tokens, encoder_inputs_mask): x = self.token_embedder(encoder_input_tokens) seq_length = encoder_input_tokens.shape[1] inputs_positions = torch.arange(seq_length, device=encoder_input_tokens.device) x += self.position_encoding(inputs_positions) x = self.dropout_pre(x) # inverted the attention mask input_shape = encoder_input_tokens.size() 
class SpectrogramDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional audio generation.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        notes_encoder ([`SpectrogramNotesEncoder`]):
        continuous_encoder ([`SpectrogramContEncoder`]):
        decoder ([`T5FilmDecoder`]):
            A [`T5FilmDecoder`] to denoise the encoded audio latents.
        scheduler ([`DDPMScheduler`]):
            A scheduler to be used in combination with `decoder` to denoise the encoded audio latents.
        melgan ([`OnnxRuntimeModel`]):
    """

    _optional_components = ["melgan"]

    def __init__(
        self,
        notes_encoder: SpectrogramNotesEncoder,
        continuous_encoder: SpectrogramContEncoder,
        decoder: T5FilmDecoder,
        scheduler: DDPMScheduler,
        melgan: OnnxRuntimeModel if is_onnx_available() else Any,
    ) -> None:
        super().__init__()

        # From MELGAN
        self.min_value = math.log(1e-5)  # Matches MelGAN training.
        self.max_value = 4.0  # Largest value for most examples
        self.n_dims = 128

        self.register_modules(
            notes_encoder=notes_encoder,
            continuous_encoder=continuous_encoder,
            decoder=decoder,
            scheduler=scheduler,
            melgan=melgan,
        )

    def scale_features(self, features, output_range=(-1.0, 1.0), clip=False):
        """Linearly scale features to network outputs range."""
        min_out, max_out = output_range
        if clip:
            features = torch.clip(features, self.min_value, self.max_value)
        # Scale to [0, 1].
        zero_one = (features - self.min_value) / (self.max_value - self.min_value)
        # Scale to [min_out, max_out].
        return zero_one * (max_out - min_out) + min_out

    def scale_to_features(self, outputs, input_range=(-1.0, 1.0), clip=False):
        """Invert by linearly scaling network outputs to features range."""
        min_out, max_out = input_range
        outputs = torch.clip(outputs, min_out, max_out) if clip else outputs
        # Scale to [0, 1].
        zero_one = (outputs - min_out) / (max_out - min_out)
        # Scale to [self.min_value, self.max_value].
        return zero_one * (self.max_value - self.min_value) + self.min_value

    def encode(self, input_tokens, continuous_inputs, continuous_mask):
        """Run both encoders and return `[(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)]`."""
        # Token id 0 is treated as padding and masked out.
        tokens_mask = input_tokens > 0
        tokens_encoded, tokens_mask = self.notes_encoder(
            encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask
        )

        continuous_encoded, continuous_mask = self.continuous_encoder(
            encoder_inputs=continuous_inputs, encoder_inputs_mask=continuous_mask
        )

        return [(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)]

    def decode(self, encodings_and_masks, input_tokens, noise_time):
        """Call the decoder on `input_tokens` at `noise_time`, conditioned on `encodings_and_masks`."""
        timesteps = noise_time
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=input_tokens.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(input_tokens.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(input_tokens.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        logits = self.decoder(
            encodings_and_masks=encodings_and_masks, decoder_input_tokens=input_tokens, decoder_noise_time=timesteps
        )
        return logits

    @torch.no_grad()
    def __call__(
        self,
        input_tokens: List[List[int]],
        generator: Optional[torch.Generator] = None,
        num_inference_steps: int = 100,
        return_dict: bool = True,
        output_type: str = "np",
        callback: Optional[Callable[[int, np.ndarray], None]] = None,
        callback_steps: int = 1,
    ) -> Union[AudioPipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            input_tokens (`List[List[int]]`):
                One token sequence per audio segment; segments are generated in order, each conditioned on the
                mel spectrogram predicted for the previous one.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
                expense of slower inference.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated audio. `"np"` runs the `melgan` component on the predicted mel
                spectrogram; any other value returns the mel spectrogram itself.
            callback (`Callable`, *optional*):
                A function called after every `callback_steps`-th generated segment, as
                `callback(segment_index: int, full_pred_mel: np.ndarray)`, where `full_pred_mel` is the mel
                spectrogram accumulated so far.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency (in segments) at which the `callback` function is called. If not specified, the
                callback is called after every segment.

        Example:

        ```py
        >>> from diffusers import SpectrogramDiffusionPipeline, MidiProcessor

        >>> pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion")
        >>> pipe = pipe.to("cuda")
        >>> processor = MidiProcessor()

        >>> # Download MIDI from: wget http://www.piano-midi.de/midis/beethoven/beethoven_hammerklavier_2.mid
        >>> output = pipe(processor("beethoven_hammerklavier_2.mid"))

        >>> audio = output.audios[0]
        ```

        Returns:
            [`pipelines.AudioPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated audio.

        Raises:
            ValueError: If `callback_steps` is not a positive integer, or if `output_type` is `"np"` but ONNX or
                the `melgan` component is unavailable.
        """
        if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        pred_mel = np.zeros([1, TARGET_FEATURE_LENGTH, self.n_dims], dtype=np.float32)
        full_pred_mel = np.zeros([1, 0, self.n_dims], np.float32)
        ones = torch.ones((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device)

        for i, encoder_input_tokens in enumerate(input_tokens):
            if i == 0:
                encoder_continuous_inputs = torch.from_numpy(pred_mel[:1].copy()).to(
                    device=self.device, dtype=self.decoder.dtype
                )
                # The first chunk has no previous context.
                encoder_continuous_mask = torch.zeros((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device)
            else:
                # The full song pipeline does not feed in a context feature, so the mask
                # will be all 0s after the feature converter. Because we know we're
                # feeding in a full context chunk from the previous prediction, set it
                # to all 1s.
                encoder_continuous_mask = ones

            encoder_continuous_inputs = self.scale_features(
                encoder_continuous_inputs, output_range=[-1.0, 1.0], clip=True
            )

            encodings_and_masks = self.encode(
                input_tokens=torch.IntTensor([encoder_input_tokens]).to(device=self.device),
                continuous_inputs=encoder_continuous_inputs,
                continuous_mask=encoder_continuous_mask,
            )

            # Sample encoder_continuous_inputs shaped gaussian noise to begin loop
            x = randn_tensor(
                shape=encoder_continuous_inputs.shape,
                generator=generator,
                device=self.device,
                dtype=self.decoder.dtype,
            )

            # set step values
            self.scheduler.set_timesteps(num_inference_steps)

            # Denoising diffusion loop
            for t in self.progress_bar(self.scheduler.timesteps):
                output = self.decode(
                    encodings_and_masks=encodings_and_masks,
                    input_tokens=x,
                    noise_time=t / self.scheduler.config.num_train_timesteps,  # rescale to [0, 1)
                )

                # Compute previous output: x_t -> x_t-1
                x = self.scheduler.step(output, t, x, generator=generator).prev_sample

            mel = self.scale_to_features(x, input_range=[-1.0, 1.0])
            # The prediction for this segment becomes the context of the next one.
            encoder_continuous_inputs = mel[:1]
            pred_mel = mel.cpu().float().numpy()

            full_pred_mel = np.concatenate([full_pred_mel, pred_mel[:1]], axis=1)

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, full_pred_mel)

            # Lazy %-style argument: passing `i` without a placeholder makes `logging` report a formatting error.
            logger.info("Generated segment %d", i)

        if output_type == "np" and not is_onnx_available():
            raise ValueError(
                "Cannot return output in 'np' format if ONNX is not available. Make sure to have ONNX installed or set 'output_type' to 'mel'."
            )
        elif output_type == "np" and self.melgan is None:
            raise ValueError(
                "Cannot return output in 'np' format if melgan component is not defined. Make sure to define `self.melgan` or set 'output_type' to 'mel'."
            )

        if output_type == "np":
            output = self.melgan(input_features=full_pred_mel.astype(np.float32))
        else:
            output = full_pred_mel

        if not return_dict:
            return (output,)

        return AudioPipelineOutput(audios=output)
is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ....utils.dummy_torch_and_transformers_objects import * else: from .pipeline_cycle_diffusion import CycleDiffusionPipeline from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL.Image import torch from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ....configuration_utils import FrozenDict from ....image_processor import PipelineImageInput, VaeImageProcessor from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ....models import AutoencoderKL, UNet2DConditionModel from ....models.lora import adjust_lora_scale_text_encoder from ....schedulers import DDIMScheduler from ....utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ....utils.torch_utils import randn_tensor from ...pipeline_utils import DiffusionPipeline from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) 
instead" deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): image = [image] if isinstance(image[0], PIL.Image.Image): w, h = image[0].size w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) image = 2.0 * image - 1.0 image = torch.from_numpy(image) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, dim=0) return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" ): if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: raise AttributeError("Could not access latents of provided encoder_output") def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta): # 1. get previous step value (=t-1) prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps if prev_timestep <= 0: return clean_latents # 2. 
compute alphas, betas alpha_prod_t = scheduler.alphas_cumprod[timestep] alpha_prod_t_prev = ( scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod ) variance = scheduler._get_variance(timestep, prev_timestep) std_dev_t = eta * variance ** (0.5) # direction pointing to x_t e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5) dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t noise = std_dev_t * randn_tensor( clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator ) prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise return prev_latents def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta): # 1. get previous step value (=t-1) prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps # 2. compute alphas, betas alpha_prod_t = scheduler.alphas_cumprod[timestep] alpha_prod_t_prev = ( scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod ) beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) # 4. Clip "predicted x_0" if scheduler.config.clip_sample: pred_original_sample = torch.clamp(pred_original_sample, -1, 1) # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = scheduler._get_variance(timestep, prev_timestep) std_dev_t = eta * variance ** (0.5) # 6. 
compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred noise = (prev_latents - (alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction)) / ( variance ** (0.5) * eta ) return noise class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin): r""" Pipeline for text-guided image to image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can only be an instance of [`DDIMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. 
feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: DDIMScheduler, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. 
For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. 
If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )

        # concatenate for backwards comp
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds

    def check_inputs(
        self,
        prompt,
        strength,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """
        Validate the call arguments and raise on invalid values or combinations.

        Raises:
            ValueError: If `strength` is outside `[0, 1]`; if `callback_steps` is given but is not a positive
                integer; if `callback_on_step_end_tensor_inputs` contains a name that is not in
                `self._callback_tensor_inputs`; if both or neither of `prompt` and `prompt_embeds` are given; if
                `prompt` is neither a `str` nor a `list`; if both `negative_prompt` and `negative_prompt_embeds` are
                given; or if `prompt_embeds` and `negative_prompt_embeds` are both given with different shapes.
        """
        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        # Unlike the other checks, `callback_steps=None` is accepted here.
        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # NOTE(review): `_callback_tensor_inputs` is not defined in the visible part of this class; confirm it is
        # provided elsewhere before passing `callback_on_step_end_tensor_inputs`.
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)

        return timesteps, num_inference_steps - t_start

    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        """
        Turn `image` into noised starting latents and also return the un-noised latents.

        An `image` whose second dimension has size 4 is used directly as latents; any other input is encoded with
        the VAE and multiplied by `vae.config.scaling_factor`. The latents are repeated `num_images_per_prompt`
        times and noised with `scheduler.add_noise` at `timestep`.

        Returns:
            `tuple`: `(latents, clean_latents)`, the noised latents and the latents before noise was added.

        Raises:
            ValueError: If `generator` is a list whose length differs from the image batch size, or if the image
                batch cannot be duplicated evenly to `batch_size`.
        """
        image = image.to(device=device, dtype=dtype)

        # NOTE: the `batch_size` argument is overwritten with the image batch size here, so the value passed by
        # the caller is never used below.
        batch_size = image.shape[0]

        if image.shape[1] == 4:
            # presumably 4-channel inputs are already latents — confirm against the caller
            init_latents = image

        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            if isinstance(generator, list):
                # one generator per sample: encode the images one at a time
                init_latents = [
                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                    for i in range(image.shape[0])
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

            init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)

        # add noise to latents using the timestep
        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        clean_latents = init_latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents, clean_latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        source_prompt: Union[str, List[str]],
        image: PipelineImageInput = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        source_guidance_scale: Optional[float] = 1,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image` or tensor representing an image batch to be used as the starting point. Can also accept image
                latents as `image`, but if passing latents directly it is not encoded again.
            strength (`float`, *optional*, defaults to 0.8):
                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
                on the amount of noise initially added.
When `strength` is 1, added noise is maximum and the denoising process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter is modulated by `strength`. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. source_guidance_scale (`float`, *optional*, defaults to 1): Guidance scale for the source prompt. This is useful to control the amount of influence the source prompt has for encoding. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
Example: ```py import requests import torch from PIL import Image from io import BytesIO from diffusers import CycleDiffusionPipeline, DDIMScheduler # load the pipeline # make sure you're logged in with `huggingface-cli login` model_id_or_path = "CompVis/stable-diffusion-v1-4" scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler") pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") # let's download an initial image url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png" response = requests.get(url) init_image = Image.open(BytesIO(response.content)).convert("RGB") init_image = init_image.resize((512, 512)) init_image.save("horse.png") # let's specify a prompt source_prompt = "An astronaut riding a horse" prompt = "An astronaut riding an elephant" # call the pipeline image = pipe( prompt=prompt, source_prompt=source_prompt, image=init_image, num_inference_steps=100, eta=0.1, strength=0.8, guidance_scale=2, source_guidance_scale=1, ).images[0] image.save("horse_to_elephant.png") # let's try another example # See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion url = ( "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png" ) response = requests.get(url) init_image = Image.open(BytesIO(response.content)).convert("RGB") init_image = init_image.resize((512, 512)) init_image.save("black.png") source_prompt = "A black colored car" prompt = "A blue colored car" # call the pipeline torch.manual_seed(0) image = pipe( prompt=prompt, source_prompt=source_prompt, image=init_image, num_inference_steps=100, eta=0.1, strength=0.85, guidance_scale=3, source_guidance_scale=1, ).images[0] image.save("black_to_blue.png") ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, 
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ # 1. Check inputs self.check_inputs(prompt, strength, callback_steps) # 2. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds_tuple = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, prompt_embeds=prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) source_prompt_embeds_tuple = self.encode_prompt( source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None, clip_skip=clip_skip ) if prompt_embeds_tuple[1] is not None: prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) else: prompt_embeds = prompt_embeds_tuple[0] if source_prompt_embeds_tuple[1] is not None: source_prompt_embeds = torch.cat([source_prompt_embeds_tuple[1], source_prompt_embeds_tuple[0]]) else: source_prompt_embeds = source_prompt_embeds_tuple[0] # 4. Preprocess image image = self.image_processor.preprocess(image) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. 
Prepare latent variables latents, clean_latents = self.prepare_latents( image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) source_latents = latents # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) generator = extra_step_kwargs.pop("generator", None) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents source_latent_model_input = ( torch.cat([source_latents] * 2) if do_classifier_free_guidance else source_latents ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t) # predict the noise residual if do_classifier_free_guidance: concat_latent_model_input = torch.stack( [ source_latent_model_input[0], latent_model_input[0], source_latent_model_input[1], latent_model_input[1], ], dim=0, ) concat_prompt_embeds = torch.stack( [ source_prompt_embeds[0], prompt_embeds[0], source_prompt_embeds[1], prompt_embeds[1], ], dim=0, ) else: concat_latent_model_input = torch.cat( [ source_latent_model_input, latent_model_input, ], dim=0, ) concat_prompt_embeds = torch.cat( [ source_prompt_embeds, prompt_embeds, ], dim=0, ) concat_noise_pred = self.unet( concat_latent_model_input, t, cross_attention_kwargs=cross_attention_kwargs, encoder_hidden_states=concat_prompt_embeds, ).sample # perform guidance if do_classifier_free_guidance: ( source_noise_pred_uncond, noise_pred_uncond, source_noise_pred_text, noise_pred_text, ) = concat_noise_pred.chunk(4, dim=0) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text 
- noise_pred_uncond) source_noise_pred = source_noise_pred_uncond + source_guidance_scale * ( source_noise_pred_text - source_noise_pred_uncond ) else: (source_noise_pred, noise_pred) = concat_noise_pred.chunk(2, dim=0) # Sample source_latents from the posterior distribution. prev_source_latents = posterior_sample( self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs ) # Compute noise. noise = compute_noise( self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs ) source_latents = prev_source_latents # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs ).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # 9. 
Post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) self.maybe_free_model_hooks() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py ================================================ import inspect from typing import Callable, List, Optional, Union import numpy as np import PIL.Image import torch from transformers import CLIPImageProcessor, CLIPTokenizer from ....configuration_utils import FrozenDict from ....schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ....utils import deprecate, logging from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel from ...pipeline_utils import DiffusionPipeline from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name def preprocess(image): w, h = image.size w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL.Image.LANCZOS) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) return 2.0 * image - 1.0 def preprocess_mask(mask, scale_factor=8): mask = mask.convert("L") w, h = mask.size w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 mask = mask.resize((w // scale_factor, h // scale_factor), 
class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
    r"""
    Pipeline for text-guided image inpainting using Stable Diffusion. This is a *legacy feature* for Onnx pipelines to
    provide compatibility with StableDiffusionInpaintPipelineLegacy and may be removed in the future.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae_encoder ([`OnnxRuntimeModel`]):
            ONNX model called with `sample=image` to encode the input image into latents.
        vae_decoder ([`OnnxRuntimeModel`]):
            ONNX model called with `latent_sample=latents` to decode latents back into images.
        text_encoder ([`OnnxRuntimeModel`]):
            ONNX text encoder called with the tokenized `input_ids` to produce the prompt embeddings.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`OnnxRuntimeModel`]):
            ONNX U-Net called with `sample`, `timestep` and `encoder_hidden_states` to predict the noise residual.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`OnnxRuntimeModel`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
            May be `None` to disable the check.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
        requires_safety_checker (`bool`, *optional*, defaults to `True`):
            Stored in the pipeline config. When `True` and `safety_checker` is `None`, a warning is logged.
    """

    _optional_components = ["safety_checker", "feature_extractor"]
    _is_onnx = True

    vae_encoder: OnnxRuntimeModel
    vae_decoder: OnnxRuntimeModel
    text_encoder: OnnxRuntimeModel
    tokenizer: CLIPTokenizer
    unet: OnnxRuntimeModel
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
    safety_checker: OnnxRuntimeModel
    feature_extractor: CLIPImageProcessor

    def __init__(
        self,
        vae_encoder: OnnxRuntimeModel,
        vae_decoder: OnnxRuntimeModel,
        text_encoder: OnnxRuntimeModel,
        tokenizer: CLIPTokenizer,
        unet: OnnxRuntimeModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: OnnxRuntimeModel,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        """
        Register the pipeline components, patching outdated scheduler configs on the way.

        Raises:
            ValueError: If a `safety_checker` is given without a `feature_extractor`.
        """
        super().__init__()

        # Outdated scheduler configs are patched in place (with a deprecation warning) rather than rejected.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # The f-prefix was missing here, so `{self.__class__}` used to be printed literally.
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            vae_encoder=vae_encoder,
            vae_decoder=vae_decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: Optional[int],
        do_classifier_free_guidance: bool,
        negative_prompt: Optional[str],
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            prompt_embeds (`np.ndarray`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`np.ndarray`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.

        Returns:
            `np.ndarray`: The prompt embeddings, repeated `num_images_per_prompt` times. With classifier free
            guidance the negative embeddings are concatenated in front of them.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="np",
            )
            text_input_ids = text_inputs.input_ids
            # Tokenize again without truncation only to detect (and report) what was cut off.
            untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

            if not np.array_equal(text_input_ids, untruncated_ids):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]

        prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt] * batch_size
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # Pad the negative prompt to the same sequence length as the positive embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="np",
            )
            negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]

        if do_classifier_free_guidance:
            negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def check_inputs(
        self,
        prompt,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """
        Validate the user-facing arguments of `__call__`.

        Raises:
            ValueError: If `callback_steps` is not a positive integer, if `prompt`/`prompt_embeds` (or their negative
                counterparts) are both given, if neither `prompt` nor `prompt_embeds` is given, if `prompt` has an
                unsupported type, or if the two embedding arrays differ in shape.
        """
        # `None` is not an `int`, so the isinstance test also rejects a missing value.
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[np.ndarray, PIL.Image.Image] = None,
        mask_image: Union[np.ndarray, PIL.Image.Image] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[np.random.RandomState] = None,
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`np.ndarray` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. This is the image whose masked region will be inpainted. Required despite the `None`
                default.
            mask_image (`np.ndarray` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it
                should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
                Required despite the `None` default.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter will be modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`np.random.RandomState`, *optional*):
                A np.random.RandomState to make generation deterministic.
            prompt_embeds (`np.ndarray`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`np.ndarray`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.

        Raises:
            ValueError: If the inputs fail `check_inputs`, if `strength` is outside `[0, 1]`, if `image` or
                `mask_image` is missing, or if the mask and the encoded image differ in shape.
        """
        # check inputs. Raise error if not correct
        self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        # Both default to `None` only for signature compatibility; fail early with a clear message instead of an
        # `AttributeError` deep inside the preprocessing code.
        if image is None or mask_image is None:
            raise ValueError("Both `image` and `mask_image` have to be provided for inpainting.")

        if generator is None:
            # The `np.random` module exposes the same `randn` function used below.
            generator = np.random

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        if isinstance(image, PIL.Image.Image):
            image = preprocess(image)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_embeds = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # All arrays fed to the ONNX models follow the dtype of the text embeddings.
        latents_dtype = prompt_embeds.dtype
        image = image.astype(latents_dtype)

        # encode the init image into latents and scale the latents
        init_latents = self.vae_encoder(sample=image)[0]
        init_latents = 0.18215 * init_latents

        # Expand init_latents for num_images_per_prompt
        # NOTE(review): the latents are not expanded by `batch_size`, while `timesteps` below is -- confirm
        # whether several prompts with a single image are meant to be supported.
        init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0)
        init_latents_orig = init_latents

        # preprocess mask
        if not isinstance(mask_image, np.ndarray):
            mask_image = preprocess_mask(mask_image, 8)
        mask_image = mask_image.astype(latents_dtype)
        mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0)

        # check sizes
        if mask.shape != init_latents.shape:
            raise ValueError("The mask and image should be the same size!")

        # get the original timestep using init_timestep
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)

        timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
        timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)

        # add noise to latents using the timesteps
        noise = generator.randn(*init_latents.shape).astype(latents_dtype)
        init_latents = self.scheduler.add_noise(
            torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
        )
        init_latents = init_latents.numpy()

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in inspect.signature(self.scheduler.step).parameters
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        latents = init_latents

        t_start = max(num_inference_steps - init_timestep + offset, 0)
        timesteps = self.scheduler.timesteps[t_start:].numpy()
        # Use the dtype the ONNX graph declares for its `timestep` input, falling back to float.
        timestep_dtype = next(
            (
                model_input.type
                for model_input in self.unet.model.get_inputs()
                if model_input.name == "timestep"
            ),
            "tensor(float)",
        )
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            timestep = np.array([t], dtype=timestep_dtype)
            noise_pred = self.unet(
                sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
            ).prev_sample
            latents = latents.numpy()

            # Re-noise the original image latents to the current timestep and blend them with the denoised
            # latents according to the mask.
            init_latents_proper = self.scheduler.add_noise(
                torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.from_numpy(np.array([t]))
            )
            init_latents_proper = init_latents_proper.numpy()

            latents = (init_latents_proper * mask) + (latents * (1 - mask))

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

        # Undo the latent scaling applied after encoding.
        latents = 1 / 0.18215 * latents
        # Decode one sample at a time: it seems likes there is a strange result for using half-precision vae
        # decoder if batchsize>1
        image = np.concatenate(
            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
        )

        # Map from [-1, 1] to [0, 1] and move channels last.
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(image), return_tensors="np"
            ).pixel_values.astype(image.dtype)
            # There will throw an error if use safety_checker batchsize>1
            images, has_nsfw_concept = [], []
            for i in range(image.shape[0]):
                image_i, has_nsfw_concept_i = self.safety_checker(
                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
                )
                images.append(image_i)
                has_nsfw_concept.append(has_nsfw_concept_i[0])
            image = np.concatenate(images)
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def preprocess_image(image, batch_size):
    """Turn an image object into a batched float tensor for the VAE encoder.

    Args:
        image: Object exposing `.size` and `.resize(...)`; presumably a
            `PIL.Image.Image` with three colour channels (TODO confirm, the
            transpose below needs an H x W x C array).
        batch_size: Number of times the image is repeated along dim 0.

    Returns:
        `torch.Tensor` laid out as (batch, channels, height, width), with the
        pixel values divided by 255 and then mapped through `2 * x - 1`.
    """
    width, height = image.size
    # Snap both sides down to the nearest multiple of 8.
    width -= width % 8
    height -= height % 8
    resized = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])

    pixels = np.array(resized).astype(np.float32) / 255.0
    channels_first = pixels[None].transpose(0, 3, 1, 2)
    batch = np.vstack([channels_first] * batch_size)
    return 2.0 * torch.from_numpy(batch) - 1.0


def preprocess_mask(mask, batch_size, scale_factor=8):
    """Bring an inpainting mask down to latent resolution.

    Two input kinds are handled differently:

    * `torch.Tensor`: must be 4-D with a channel axis of size 1 or 3 in the
      second or fourth position. Channels are averaged to one and the result
      is resized with `torch.nn.functional.interpolate`. The values are NOT
      inverted and `batch_size` is not used on this path.
    * anything else (treated like a PIL image): converted to mode "L",
      resized with nearest-neighbour sampling, tiled to 4 channels, repeated
      `batch_size` times and inverted (`1 - mask`).

    Args:
        mask: Tensor mask or image-like mask.
        batch_size: Repeat count for image-like masks.
        scale_factor: Divisor applied to the (multiple-of-8) spatial size.

    Returns:
        `torch.Tensor` mask at the reduced spatial size.

    Raises:
        ValueError: If a tensor mask has no channel axis of size 1 or 3 in
            the second or fourth dimension.
    """
    if isinstance(mask, torch.Tensor):
        valid_mask_channel_sizes = [1, 3]
        # Channels-last input is permuted to the (B, C, H, W) layout first.
        if mask.shape[3] in valid_mask_channel_sizes:
            mask = mask.permute(0, 3, 1, 2)
        elif mask.shape[1] not in valid_mask_channel_sizes:
            raise ValueError(
                f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
                f" but received mask of shape {tuple(mask.shape)}"
            )
        # Collapse (possibly 3) channels into 1 so the mask broadcasts over latent channels.
        mask = mask.mean(dim=1, keepdim=True)
        height, width = mask.shape[-2:]
        height -= height % 8
        width -= width % 8
        target_size = (height // scale_factor, width // scale_factor)
        return torch.nn.functional.interpolate(mask, target_size)

    grayscale = mask.convert("L")
    width, height = grayscale.size
    width -= width % 8
    height -= height % 8
    grayscale = grayscale.resize(
        (width // scale_factor, height // scale_factor), resample=PIL_INTERPOLATION["nearest"]
    )
    values = np.array(grayscale).astype(np.float32) / 255.0
    values = np.tile(values, (4, 1, 1))
    values = np.vstack([values[None]] * batch_size)
    values = 1 - values  # repaint white, keep black
    return torch.from_numpy(values)
In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - *LoRA*: [`loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - *LoRA*: [`loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() deprecation_message = ( f"The class {self.__class__} is deprecated and will be removed in v1.0.0. You can achieve exactly the same functionality" "by loading your model into `StableDiffusionInpaintPipeline` instead. See https://github.com/huggingface/diffusers/pull/3533" "for more information." ) deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False) if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. 
Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Returns:
            `tuple`: `(prompt_embeds, negative_prompt_embeds)`. Both are repeated `num_images_per_prompt` times along
            the batch dimension. When `do_classifier_free_guidance` is False, `negative_prompt_embeds` is returned
            exactly as it was passed in (possibly `None`).

        Raises:
            TypeError: If `prompt` and `negative_prompt` are both given but have different types.
            ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # Batch size comes from the prompt when one is given, otherwise from the precomputed embeddings.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Second, untruncated tokenization is used only to detect and report what truncation removed.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            # The attention mask is forwarded only when the text encoder config asks for it.
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        # Target dtype preference: text encoder, then UNet, then whatever the embeddings already use.
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                # No negative prompt: use one empty string per batch entry.
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad the negative prompt to the same sequence length as the positive embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] return timesteps, num_inference_steps - t_start def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator): image = image.to(device=device, dtype=dtype) init_latent_dist = self.vae.encode(image).latent_dist init_latents = init_latent_dist.sample(generator=generator) init_latents = self.vae.config.scaling_factor * init_latents # Expand init_latents for batch_size and num_images_per_prompt init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) init_latents_orig = init_latents # add noise to latents using the timesteps noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents return latents, init_latents_orig, noise @torch.no_grad() def __call__( self, prompt: Union[str, List[str]] = None, image: 
Union[torch.Tensor, PIL.Image.Image] = None, mask_image: Union[torch.Tensor, PIL.Image.Image] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, add_predicted_noise: Optional[bool] = False, eta: Optional[float] = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.Tensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. This is the image whose masked region will be inpainted. mask_image (`torch.Tensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` is 1, the denoising process will be run on the masked area for the full number of iterations specified in `num_inference_steps`. 
`image` will be used as a reference for the masked area, adding more noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. num_inference_steps (`int`, *optional*, defaults to 50): The reference number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter will be modulated by `strength`, as explained above. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. add_predicted_noise (`bool`, *optional*, defaults to True): Use predicted noise instead of random noise when constructing noisy versions of the original image in the reverse diffusion process eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Preprocess image and mask if not isinstance(image, torch.Tensor): image = preprocess_image(image, batch_size) mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor) # 5. 
set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables # encode the init image into latents and scale the latents latents, init_latents_orig, noise = self.prepare_latents( image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator ) # 7. Prepare mask latent mask = mask_image.to(device=device, dtype=latents.dtype) mask = torch.cat([mask] * num_images_per_prompt) # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # masking if add_predicted_noise: init_latents_proper = self.scheduler.add_noise( init_latents_orig, noise_pred_uncond, torch.tensor([t]) ) else: init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) latents = (init_latents_proper * 
mask) + (latents * (1 - mask)) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # use original latents corresponding to unmasked portions of the image latents = (init_latents_orig * mask) + (latents * (1 - mask)) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py ================================================ # Copyright 2024 TIME Authors and The HuggingFace Team. All rights reserved." # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import copy import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ....image_processor import VaeImageProcessor from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ....models import AutoencoderKL, UNet2DConditionModel from ....models.lora import adjust_lora_scale_text_encoder from ....schedulers import PNDMScheduler from ....schedulers.scheduling_utils import SchedulerMixin from ....utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ....utils.torch_utils import randn_tensor from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name AUGS_CONST = ["A photo of ", "An image of ", "A picture of "] class StableDiffusionModelEditingPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin ): r""" Pipeline for text-to-image model editing. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 
text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. with_to_k ([`bool`]): Whether to edit the key projection matrices along with the value projection matrices. with_augs ([`list`]): Textual augmentations to apply while editing the text-to-image model. Set to `[]` for no augmentations. 
""" model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor"] _exclude_from_cpu_offload = ["safety_checker"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: SchedulerMixin, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, with_to_k: bool = True, with_augs: list = AUGS_CONST, ): super().__init__() if isinstance(scheduler, PNDMScheduler): logger.error("PNDMScheduler for this pipeline is currently not supported.") if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) self.with_to_k = with_to_k self.with_augs = with_augs # get cross-attention layers ca_layers = [] def append_ca(net_): if net_.__class__.__name__ == "CrossAttention": ca_layers.append(net_) elif hasattr(net_, "children"): for net__ in net_.children(): append_ca(net__) # recursively find all cross-attention layers in unet for net in self.unet.named_children(): if "down" in net[0]: append_ca(net[1]) elif "up" in net[0]: append_ca(net[1]) elif "mid" in net[0]: append_ca(net[1]) # get projection matrices self.ca_clip_layers = [l for l in ca_layers if l.to_v.in_features == 768] self.projection_matrices = [l.to_v for l in self.ca_clip_layers] self.og_matrices = [copy.deepcopy(l.to_v) for l in self.ca_clip_layers] if self.with_to_k: self.projection_matrices = self.projection_matrices + [l.to_k for l in self.ca_clip_layers] self.og_matrices = self.og_matrices + [copy.deepcopy(l.to_k) for l in self.ca_clip_layers] # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) prompt_embeds_tuple = self.encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=lora_scale, **kwargs, ) # concatenate for backwards comp prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt def encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 
attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into # the tuple to access the hidden states from the desired layer. prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if self.text_encoder is not None: if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if 
self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @torch.no_grad() def edit_model( self, source_prompt: str, destination_prompt: str, lamb: float = 0.1, restart_params: bool = True, ): r""" Apply model editing via closed-form solution (see Eq. 5 in the TIME [paper](https://arxiv.org/abs/2303.08084)). 
Args: source_prompt (`str`): The source prompt containing the concept to be edited. destination_prompt (`str`): The destination prompt. Must contain all words from `source_prompt` with additional ones to specify the target edit. lamb (`float`, *optional*, defaults to 0.1): The lambda parameter specifying the regularization intesity. Smaller values increase the editing power. restart_params (`bool`, *optional*, defaults to True): Restart the model parameters to their pre-trained version before editing. This is done to avoid edit compounding. When it is `False`, edits accumulate. """ # restart LDM parameters if restart_params: num_ca_clip_layers = len(self.ca_clip_layers) for idx_, l in enumerate(self.ca_clip_layers): l.to_v = copy.deepcopy(self.og_matrices[idx_]) self.projection_matrices[idx_] = l.to_v if self.with_to_k: l.to_k = copy.deepcopy(self.og_matrices[num_ca_clip_layers + idx_]) self.projection_matrices[num_ca_clip_layers + idx_] = l.to_k # set up sentences old_texts = [source_prompt] new_texts = [destination_prompt] # add augmentations base = old_texts[0] if old_texts[0][0:1] != "A" else "a" + old_texts[0][1:] for aug in self.with_augs: old_texts.append(aug + base) base = new_texts[0] if new_texts[0][0:1] != "A" else "a" + new_texts[0][1:] for aug in self.with_augs: new_texts.append(aug + base) # prepare input k* and v* old_embs, new_embs = [], [] for old_text, new_text in zip(old_texts, new_texts): text_input = self.tokenizer( [old_text, new_text], padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] old_emb, new_emb = text_embeddings old_embs.append(old_emb) new_embs.append(new_emb) # identify corresponding destinations for each token in old_emb idxs_replaces = [] for old_text, new_text in zip(old_texts, new_texts): tokens_a = self.tokenizer(old_text).input_ids tokens_b = self.tokenizer(new_text).input_ids tokens_a = 
[self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_a] tokens_b = [self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_b] num_orig_tokens = len(tokens_a) idxs_replace = [] j = 0 for i in range(num_orig_tokens): curr_token = tokens_a[i] while tokens_b[j] != curr_token: j += 1 idxs_replace.append(j) j += 1 while j < 77: idxs_replace.append(j) j += 1 while len(idxs_replace) < 77: idxs_replace.append(76) idxs_replaces.append(idxs_replace) # prepare batch: for each pair of setences, old context and new values contexts, valuess = [], [] for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces): context = old_emb.detach() values = [] with torch.no_grad(): for layer in self.projection_matrices: values.append(layer(new_emb[idxs_replace]).detach()) contexts.append(context) valuess.append(values) # edit the model for layer_num in range(len(self.projection_matrices)): # mat1 = \lambda W + \sum{v k^T} mat1 = lamb * self.projection_matrices[layer_num].weight # mat2 = \lambda I + \sum{k k^T} mat2 = lamb * torch.eye( self.projection_matrices[layer_num].weight.shape[1], device=self.projection_matrices[layer_num].weight.device, ) # aggregate sums for mat1, mat2 for context, values in zip(contexts, valuess): context_vector = context.reshape(context.shape[0], context.shape[1], 1) context_vector_T = context.reshape(context.shape[0], 1, context.shape[1]) value_vector = values[layer_num].reshape(values[layer_num].shape[0], values[layer_num].shape[1], 1) for_mat1 = (value_vector @ context_vector_T).sum(dim=0) for_mat2 = (context_vector @ context_vector_T).sum(dim=0) mat1 += for_mat1 mat2 += for_mat2 # update projection matrix self.projection_matrices[layer_num].weight = torch.nn.Parameter(mat1 @ torch.inverse(mat2)) @torch.no_grad() def __call__( self, prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: 
float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. 
Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. Examples: ```py >>> import torch >>> from diffusers import StableDiffusionModelEditingPipeline >>> model_ckpt = "CompVis/stable-diffusion-v1-4" >>> pipe = StableDiffusionModelEditingPipeline.from_pretrained(model_ckpt) >>> pipe = pipe.to("cuda") >>> source_prompt = "A pack of roses" >>> destination_prompt = "A pack of blue roses" >>> pipe.edit_model(source_prompt, destination_prompt) >>> prompt = "A field of roses" >>> image = pipe(prompt).images[0] ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 
src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py ================================================ # Copyright 2024 ParaDiGMS authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ....image_processor import VaeImageProcessor from ....loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ....models import AutoencoderKL, UNet2DConditionModel from ....models.lora import adjust_lora_scale_text_encoder from ....schedulers import KarrasDiffusionSchedulers from ....utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ....utils.torch_utils import randn_tensor from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import DDPMParallelScheduler >>> from diffusers import StableDiffusionParadigmsPipeline >>> scheduler = DDPMParallelScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", 
class StableDiffusionParadigmsPipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,
    StableDiffusionLoraLoaderMixin,
    FromSingleFileMixin,
):
    r"""
    Pipeline for text-to-image generation using a parallelized version of Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. The denoising
            loop of this pipeline calls `batch_step_no_noise` and `_get_variance` and reads `_is_ode_scheduler`, so
            the scheduler has to provide them (for example `DDPMParallelScheduler`, as in the usage example).
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """

    model_cpu_offload_seq = "text_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor"]
    _exclude_from_cpu_offload = ["safety_checker"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        """
        Register the pipeline components and derive the VAE scale factor and image processor.

        Raises:
            ValueError: If a `safety_checker` is given without a `feature_extractor`.
        """
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # The first fragment must be an f-string, otherwise `{self.__class__}` is printed literally.
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

        # attribute to wrap the unet with torch.nn.DataParallel when running multiple denoising steps on multiple GPUs
        self.wrapped_unet = self.unet

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        **kwargs,
    ):
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )

        # concatenate for backwards comp
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """
        Validate the user-facing arguments of `__call__`.

        Raises:
            ValueError: If `height`/`width` are not divisible by 8, `callback_steps` is not a positive integer,
                `callback_on_step_end_tensor_inputs` contains unknown names, `prompt` and `prompt_embeds` are both
                given or both missing, `prompt` has an unsupported type, `negative_prompt` and
                `negative_prompt_embeds` are both given, or the two embedding tensors differ in shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
        # NOTE(review): `_callback_tensor_inputs` is not defined in this class; it is only read when
        # `callback_on_step_end_tensor_inputs` is passed, which `__call__` never does - confirm a base provides it.
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def _cumsum(self, input, dim, debug=False):
        """
        Return the cumulative sum of `input` along `dim`.

        With `debug=True` the sum is computed on the CPU in float32 and moved back to `input.device`; otherwise
        `torch.cumsum` runs directly on `input`.
        """
        if debug:
            # cumsum_cuda_kernel does not have a deterministic implementation
            # so perform cumsum on cpu for debugging purposes
            return torch.cumsum(input.cpu().float(), dim=dim).to(input.device)
        else:
            return torch.cumsum(input, dim=dim)

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        parallel: int = 10,
        tolerance: float = 0.1,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        debug: bool = False,
        clip_skip: Optional[int] = None,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            parallel (`int`, *optional*, defaults to 10):
                The batch size to use when doing parallel sampling. More parallelism may lead to faster inference but
                requires higher memory usage and can also require more total FLOPs.
            tolerance (`float`, *optional*, defaults to 0.1):
                The error tolerance for determining when to slide the batch window forward for parallel sampling. Lower
                tolerance usually leads to less or no degradation. Higher tolerance is faster but can risk degradation
                of sample quality. The tolerance is specified as a ratio of the scheduler's noise magnitude.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. `step` is the index
                of the first timestep of the window that was just processed and `latents` is the latent stored at
                that index.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            debug (`bool`, *optional*, defaults to `False`):
                Whether or not to run in debug mode. In debug mode, `torch.cumsum` is evaluated using the CPU.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            clip_skip=clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # the noise is pre-sampled below, so the generator is not forwarded to the scheduler step
        extra_step_kwargs.pop("generator", None)

        # 7. Denoising loop
        scheduler = self.scheduler
        parallel = min(parallel, len(scheduler.timesteps))

        begin_idx = 0
        end_idx = parallel
        # buffer[i] holds the current estimate of the latent before timestep i; buffer[-1] is the final sample
        latents_time_evolution_buffer = torch.stack([latents] * (len(scheduler.timesteps) + 1))

        # We must make sure the noise of stochastic schedulers such as DDPM is sampled only once per timestep.
        # Sampling inside the parallel denoising loop will mess this up, so we pre-sample the noise vectors outside the denoising loop.
        noise_array = torch.zeros_like(latents_time_evolution_buffer)
        for j in range(len(scheduler.timesteps)):
            base_noise = randn_tensor(
                shape=latents.shape, generator=generator, device=latents.device, dtype=prompt_embeds.dtype
            )
            noise = (self.scheduler._get_variance(scheduler.timesteps[j]) ** 0.5) * base_noise
            noise_array[j] = noise.clone()

        # We specify the error tolerance as a ratio of the scheduler's noise magnitude. We similarly compute the error tolerance
        # outside of the denoising loop to avoid recomputing it at every step.
        # We will be dividing the norm of the noise, so we store its inverse here to avoid a division at every step.
        inverse_variance_norm = 1.0 / torch.tensor(
            [scheduler._get_variance(scheduler.timesteps[j]) for j in range(len(scheduler.timesteps))] + [0]
        ).to(noise_array.device)
        latent_dim = noise_array[0, 0].numel()
        inverse_variance_norm = inverse_variance_norm[:, None] / latent_dim

        scaled_tolerance = tolerance**2

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            steps = 0
            while begin_idx < len(scheduler.timesteps):
                # these have shape (parallel_dim, 2*batch_size, ...)
                # parallel_len is at most parallel, but could be less if we are at the end of the timesteps
                # we are processing batch window of timesteps spanning [begin_idx, end_idx)
                parallel_len = end_idx - begin_idx

                block_prompt_embeds = torch.stack([prompt_embeds] * parallel_len)
                block_latents = latents_time_evolution_buffer[begin_idx:end_idx]
                block_t = scheduler.timesteps[begin_idx:end_idx, None].repeat(1, batch_size * num_images_per_prompt)
                t_vec = block_t
                if do_classifier_free_guidance:
                    t_vec = t_vec.repeat(1, 2)

                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([block_latents] * 2, dim=1) if do_classifier_free_guidance else block_latents
                )
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t_vec)

                # if parallel_len is small, no need to use multiple GPUs
                net = self.wrapped_unet if parallel_len > 3 else self.unet
                # predict the noise residual, shape is now [parallel_len * 2 * batch_size * num_images_per_prompt, ...]
                model_output = net(
                    latent_model_input.flatten(0, 1),
                    t_vec.flatten(0, 1),
                    encoder_hidden_states=block_prompt_embeds.flatten(0, 1),
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                per_latent_shape = model_output.shape[1:]
                if do_classifier_free_guidance:
                    model_output = model_output.reshape(
                        parallel_len, 2, batch_size * num_images_per_prompt, *per_latent_shape
                    )
                    noise_pred_uncond, noise_pred_text = model_output[:, 0], model_output[:, 1]
                    model_output = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                model_output = model_output.reshape(
                    parallel_len * batch_size * num_images_per_prompt, *per_latent_shape
                )

                block_latents_denoise = scheduler.batch_step_no_noise(
                    model_output=model_output,
                    timesteps=block_t.flatten(0, 1),
                    sample=block_latents.flatten(0, 1),
                    **extra_step_kwargs,
                ).reshape(block_latents.shape)

                # back to shape (parallel_dim, batch_size, ...)
                # now we want to add the pre-sampled noise
                # parallel sampling algorithm requires computing the cumulative drift from the beginning
                # of the window, so we need to compute cumulative sum of the deltas and the pre-sampled noises.
                delta = block_latents_denoise - block_latents
                cumulative_delta = self._cumsum(delta, dim=0, debug=debug)
                cumulative_noise = self._cumsum(noise_array[begin_idx:end_idx], dim=0, debug=debug)

                # if we are using an ODE-like scheduler (like DDIM), we don't want to add noise
                if scheduler._is_ode_scheduler:
                    cumulative_noise = 0

                block_latents_new = (
                    latents_time_evolution_buffer[begin_idx][None,] + cumulative_delta + cumulative_noise
                )
                # squared distance between the new estimates and the ones stored by the previous iteration
                cur_error = torch.linalg.norm(
                    (block_latents_new - latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1]).reshape(
                        parallel_len, batch_size * num_images_per_prompt, -1
                    ),
                    dim=-1,
                ).pow(2)
                error_ratio = cur_error * inverse_variance_norm[begin_idx + 1 : end_idx + 1]

                # find the first index of the vector error_ratio that is greater than error tolerance
                # we can shift the window for the next iteration up to this index
                error_ratio = torch.nn.functional.pad(
                    error_ratio, (0, 0, 0, 1), value=1e9
                )  # handle the case when everything is below ratio, by padding the end of parallel_len dimension
                any_error_at_time = torch.max(error_ratio > scaled_tolerance, dim=1).values.int()
                ind = torch.argmax(any_error_at_time).item()

                # compute the new begin and end idxs for the window
                new_begin_idx = begin_idx + min(1 + ind, parallel)
                new_end_idx = min(new_begin_idx + parallel, len(scheduler.timesteps))

                # store the computed latents for the current window in the global buffer
                latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1] = block_latents_new
                # initialize the new sliding window latents with the end of the current window,
                # should be better than random initialization
                latents_time_evolution_buffer[end_idx : new_end_idx + 1] = latents_time_evolution_buffer[end_idx][
                    None,
                ]

                steps += 1

                progress_bar.update(new_begin_idx - begin_idx)
                if callback is not None and steps % callback_steps == 0:
                    # `block_t` only holds the timesteps of the current window, so its rows are indexed relative
                    # to `begin_idx`: row 0 is the timestep at global index `begin_idx`. Indexing it with the
                    # global `begin_idx` goes out of range as soon as the window has moved forward.
                    callback(begin_idx, block_t[0], latents_time_evolution_buffer[begin_idx])

                begin_idx = new_begin_idx
                end_idx = new_end_idx

        latents = latents_time_evolution_buffer[-1]

        if output_type != "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# NOTE(review): inheriting `TextualInversionLoaderMixin` on a plain output container looks unintentional —
# confirm before relying on (or removing) that base class.
@dataclass
class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin):
    """
    Output container holding inverted latents together with the corresponding images.

    The usage example in this module reads the `latents` field from the result of `pipeline.invert(...)` and
    feeds it back into the pipeline call as `latents=...`.

    Args:
        latents (`torch.Tensor`)
            inverted latents tensor
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    # Inverted latents tensor.
    latents: torch.Tensor
    # Denoised images, either as a list of PIL images or as a single numpy array.
    images: Union[List[PIL.Image.Image], np.ndarray]
sd_model_ckpt, ... caption_generator=model, ... caption_processor=processor, ... torch_dtype=torch.float16, ... safety_checker=None, ... ) >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) >>> pipeline.enable_model_cpu_offload() >>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) >>> # generate caption >>> caption = pipeline.generate_caption(raw_image) >>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii" >>> inv_latents = pipeline.invert(caption, image=raw_image).latents >>> # we need to generate source and target embeds >>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] >>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] >>> source_embeds = pipeline.get_embeds(source_prompts) >>> target_embeds = pipeline.get_embeds(target_prompts) >>> # the latents can then be used to edit a real image >>> # when using Stable Diffusion 2 or other models that use v-prediction >>> # set `cross_attention_guidance_amount` to 0.01 or less to avoid input latent gradient explosion >>> image = pipeline( ... caption, ... source_embeds=source_embeds, ... target_embeds=target_embeds, ... num_inference_steps=50, ... cross_attention_guidance_amount=0.15, ... generator=generator, ... latents=inv_latents, ... negative_prompt=caption, ... ).images[0] >>> image.save("edited_image.png") ``` """ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) 
class Pix2PixZeroAttnProcessor:
    """Attention processor that can record attention probabilities and later compare against them.

    When constructed with `is_pix2pix_zero=True` the processor keeps a dictionary keyed by `timestep.item()`:

    - called with a `timestep` and no `loss`, it stores a detached CPU copy of the attention probabilities;
    - called with a `timestep` and a `loss` object, it removes the stored entry for that timestep and passes the
      current and stored probabilities to `loss.compute_loss`.

    In every other case it only computes the attention output.
    """

    def __init__(self, is_pix2pix_zero=False):
        self.is_pix2pix_zero = is_pix2pix_zero
        # The reference map only exists on processors that take part in the bookkeeping.
        if is_pix2pix_zero:
            self.reference_cross_attn_map = {}

    def _record_or_compare(self, attention_probs, timestep, loss):
        """Store the probabilities for `timestep`, or feed them to `loss` together with the stored ones."""
        step_key = timestep.item()
        if loss is None:
            # First pass: remember the attention weights for this timestep.
            self.reference_cross_attn_map[step_key] = attention_probs.detach().cpu()
            return
        # Later pass: consume the stored weights (a missing entry raises KeyError) and accumulate the loss.
        reference_probs = self.reference_cross_attn_map.pop(step_key)
        loss.compute_loss(attention_probs, reference_probs.to(attention_probs.device))

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        timestep=None,
        loss=None,
    ):
        """Run attention through `attn` and return the projected hidden states.

        Args:
            attn: Attention module providing the projections and helper methods used below.
            hidden_states: 3-D tensor unpacked as `(batch_size, sequence_length, _)`.
            encoder_hidden_states: Optional context; when `None`, `hidden_states` is used as the context.
            attention_mask: Optional mask, forwarded through `attn.prepare_attention_mask`.
            timestep: Optional tensor whose `.item()` is the key of the reference map.
            loss: Optional object exposing `compute_loss(current, reference)`.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        context = encoder_hidden_states
        if context is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        if self.is_pix2pix_zero and timestep is not None:
            self._record_or_compare(attention_probs, timestep, loss)

        attended = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout
        projected = attn.to_out[0](attended)
        return attn.to_out[1](projected)
Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], or [`DDPMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. requires_safety_checker (bool): Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the pipeline publicly. 
""" model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = [ "safety_checker", "feature_extractor", "caption_generator", "caption_processor", "inverse_scheduler", ] _exclude_from_cpu_offload = ["safety_checker"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler], feature_extractor: CLIPImageProcessor, safety_checker: StableDiffusionSafetyChecker, inverse_scheduler: DDIMInverseScheduler, caption_generator: BlipForConditionalGeneration, caption_processor: BlipProcessor, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, caption_processor=caption_processor, caption_generator=caption_generator, inverse_scheduler=inverse_scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Returns:
            `tuple`: `(prompt_embeds, negative_prompt_embeds)`. Both are repeated `num_images_per_prompt` times along
            the batch dimension. When `do_classifier_free_guidance` is `False`, `negative_prompt_embeds` is returned
            exactly as it was passed in (possibly `None`).

        Raises:
            TypeError: If `negative_prompt` and `prompt` are of different types.
            ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # The batch size comes from whichever prompt representation was supplied.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Tokenize again without truncation, only to detect and report text that was cut off.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            # Only pass an attention mask when the text encoder config asks for one.
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        # Pick the dtype to cast the embeddings to: text encoder first, then UNet, else keep the current dtype.
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                # No negative prompt: use one empty string per prompt.
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad the unconditional input to the same sequence length as the prompt embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, source_embeds, target_embeds, callback_steps, prompt_embeds=None, ): if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if source_embeds is None and target_embeds is None: raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.") if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @torch.no_grad() def generate_caption(self, images): """Generates caption for a given image.""" text = "a photography of" prev_device = self.caption_generator.device device = self._execution_device inputs = self.caption_processor(images, text, return_tensors="pt").to( device=device, dtype=self.caption_generator.dtype ) self.caption_generator.to(device) outputs = self.caption_generator.generate(**inputs, max_new_tokens=128) # offload caption generator self.caption_generator.to(prev_device) caption = self.caption_processor.batch_decode(outputs, skip_special_tokens=True)[0] return caption def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor): """Constructs the edit direction to steer the image generation process semantically.""" return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) @torch.no_grad() def get_embeds(self, prompt: List[str], batch_size: int = 16) -> 
torch.Tensor: num_prompts = len(prompt) embeds = [] for i in range(0, num_prompts, batch_size): prompt_slice = prompt[i : i + batch_size] input_ids = self.tokenizer( prompt_slice, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ).input_ids input_ids = input_ids.to(self.text_encoder.device) embeds.append(self.text_encoder(input_ids)[0]) return torch.cat(embeds, dim=0).mean(0)[None] def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) image = image.to(device=device, dtype=dtype) if image.shape[1] == 4: latents = image else: if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if isinstance(generator, list): latents = [ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] latents = torch.cat(latents, dim=0) else: latents = self.vae.encode(image).latent_dist.sample(generator) latents = self.vae.config.scaling_factor * latents if batch_size != latents.shape[0]: if batch_size % latents.shape[0] == 0: # expand image_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." 
) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_latents_per_image = batch_size // latents.shape[0] latents = torch.cat([latents] * additional_latents_per_image, dim=0) else: raise ValueError( f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts." ) else: latents = torch.cat([latents], dim=0) return latents def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): pred_type = self.inverse_scheduler.config.prediction_type alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] beta_prod_t = 1 - alpha_prod_t if pred_type == "epsilon": return model_output elif pred_type == "sample": return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) elif pred_type == "v_prediction": return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" ) def auto_corr_loss(self, hidden_states, generator=None): reg_loss = 0.0 for i in range(hidden_states.shape[0]): for j in range(hidden_states.shape[1]): noise = hidden_states[i : i + 1, j : j + 1, :, :] while True: roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 if noise.shape[2] <= 8: break noise = F.avg_pool2d(noise, kernel_size=2) return reg_loss def kl_divergence(self, hidden_states): mean = hidden_states.mean() var = hidden_states.var() return var + mean**2 - 1 - torch.log(var + 1e-7) @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Optional[Union[str, List[str]]] = None, source_embeds: torch.Tensor = None, target_embeds: torch.Tensor = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 
50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, cross_attention_guidance_amount: float = 0.1, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: Optional[int] = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. source_embeds (`torch.Tensor`): Source concept embeddings. Generation of the embeddings as per the [original paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. target_embeds (`torch.Tensor`): Target concept embeddings. Generation of the embeddings as per the [original paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). 
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. cross_attention_guidance_amount (`float`, defaults to 0.1): Amount of guidance needed from the reference cross-attention maps. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. 
Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Define the spatial resolutions. height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, source_embeds, target_embeds, callback_steps, prompt_embeds, ) # 3. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if cross_attention_kwargs is None: cross_attention_kwargs = {} device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clip_skip=clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Generate the inverted noise from the input image or any other image # generated from the input prompt. num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) latents_init = latents.clone() # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. Rejig the UNet so that we can obtain the cross-attenion maps and # use them for guiding the subsequent image generation. self.unet = prepare_unet(self.unet) # 7. Denoising loop where we obtain the cross-attention maps. 
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs={"timestep": t}, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # 8. Compute the edit directions. edit_direction = self.construct_direction(source_embeds, target_embeds).to(prompt_embeds.device) # 9. Edit the prompt embeddings as per the edit directions discovered. prompt_embeds_edit = prompt_embeds.clone() prompt_embeds_edit[1:2] += edit_direction # 10. Second denoising loop to generate the edited image. 
self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps latents = latents_init num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # we want to learn the latent such that it steers the generation # process towards the edited direction, so make the make initial # noise learnable x_in = latent_model_input.detach().clone() x_in.requires_grad = True # optimizer opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount) with torch.enable_grad(): # initialize loss loss = Pix2PixZeroL2Loss() # predict the noise residual noise_pred = self.unet( x_in, t, encoder_hidden_states=prompt_embeds_edit.detach(), cross_attention_kwargs={"timestep": t, "loss": loss}, ).sample loss.loss.backward(retain_graph=False) opt.step() # recompute the noise noise_pred = self.unet( x_in.detach(), t, encoder_hidden_states=prompt_embeds_edit, cross_attention_kwargs={"timestep": None}, ).sample latents = x_in.detach().chunk(2)[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_INVERT_DOC_STRING)
def invert(
    self,
    prompt: Optional[str] = None,
    image: PipelineImageInput = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    cross_attention_guidance_amount: float = 0.1,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: Optional[int] = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    lambda_auto_corr: float = 20.0,
    lambda_kl: float = 20.0,
    num_reg_steps: int = 5,
    num_auto_corr_rolls: int = 5,
):
    r"""
    Function used to generate inverted latents given a prompt and image.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image`, or tensor representing an image batch which will be used for conditioning. Can also accept
            image latents as `image`, if passing latents directly, it will not be encoded again.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 1):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Accepted for signature compatibility. The value is not read by this method: the starting latents are
            always computed from `image`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        cross_attention_guidance_amount (`float`, defaults to 0.1):
            Amount of guidance needed from the reference cross-attention maps. Not read by this method.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a
            [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`]
            instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            Accepted for signature compatibility. Not read by this method.
        lambda_auto_corr (`float`, *optional*, defaults to 20.0):
            Lambda parameter to control auto correction
        lambda_kl (`float`, *optional*, defaults to 20.0):
            Lambda parameter to control Kullback–Leibler divergence output
        num_reg_steps (`int`, *optional*, defaults to 5):
            Number of regularization loss steps
        num_auto_corr_rolls (`int`, *optional*, defaults to 5):
            Number of auto correction roll steps

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] or
        `tuple`:
        [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] if
        `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the inverted
        latents tensor and the second is the corresponding decoded image.

    Raises:
        ValueError: If neither a `str`/`list` `prompt` nor `prompt_embeds` is provided.
    """
    # 1. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    elif prompt_embeds is not None:
        batch_size = prompt_embeds.shape[0]
    else:
        # Previously this path failed with an AttributeError on `None.shape`.
        raise ValueError("Provide either `prompt` (as `str` or `list`) or `prompt_embeds`.")

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 2. Preprocess image
    image = self.image_processor.preprocess(image)

    # 3. Prepare latent variables
    latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator)

    # 4. Encode input prompt
    num_images_per_prompt = 1
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 5. Prepare timesteps
    self.inverse_scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.inverse_scheduler.timesteps

    # 6. Rejig the UNet so that we can obtain the cross-attention maps and
    # use them for guiding the subsequent image generation.
    self.unet = prepare_unet(self.unet)

    # 7. Denoising loop where we obtain the cross-attention maps.
    num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs={"timestep": t},
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # regularization of the noise prediction; the method itself runs under
            # `torch.no_grad()`, so gradients are re-enabled locally.
            with torch.enable_grad():
                for _reg_step in range(num_reg_steps):
                    if lambda_auto_corr > 0:
                        for _roll in range(num_auto_corr_rolls):
                            # Fresh leaf tensor so `.grad` is populated by `backward()`.
                            var = noise_pred.detach().clone().requires_grad_(True)

                            # Derive epsilon from model output before regularizing to IID standard normal
                            var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)

                            l_ac = self.auto_corr_loss(var_epsilon, generator=generator)
                            l_ac.backward()

                            grad = var.grad.detach() / num_auto_corr_rolls
                            noise_pred = noise_pred - lambda_auto_corr * grad

                    if lambda_kl > 0:
                        var = noise_pred.detach().clone().requires_grad_(True)

                        # Derive epsilon from model output before regularizing to IID standard normal
                        var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)

                        l_kld = self.kl_divergence(var_epsilon)
                        l_kld.backward()

                        grad = var.grad.detach()
                        noise_pred = noise_pred - lambda_kl * grad

                    noise_pred = noise_pred.detach()

            # compute the next noisy sample with the inverse scheduler
            latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    # The loop is driven by `inverse_scheduler`, so its order defines the step index.
                    step_idx = i // getattr(self.inverse_scheduler, "order", 1)
                    callback(step_idx, t, latents)

    inverted_latents = latents.detach().clone()

    # 8. Post-processing
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (inverted_latents, image)

    return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image)
class KarrasVePipeline(DiffusionPipeline):
    r"""
    Unconditional image generation pipeline built from a UNet and a Karras VE scheduler.

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image.
        scheduler ([`KarrasVeScheduler`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image.
    """

    # add type hints for linting
    unet: UNet2DModel
    scheduler: KarrasVeScheduler

    def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 50,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        Run the sampling loop and return the generated images.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.

        Example:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """
        resolution = self.unet.config.sample_size
        noise_shape = (batch_size, 3, resolution, resolution)
        denoiser = self.unet

        # sample x_0 ~ N(0, sigma_0^2 * I)
        initial_noise = randn_tensor(noise_shape, generator=generator, device=self.device)
        sample = initial_noise * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # here sigma_t == t_i from the paper
            sigma = self.scheduler.schedule[t]
            if t > 0:
                sigma_prev = self.scheduler.schedule[t - 1]
            else:
                sigma_prev = 0

            # 1. Select temporarily increased noise level sigma_hat
            # 2. Add new noise to move from sample_i to sample_hat
            noised_sample, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)

            # 3. Predict the noise residual given the noise magnitude `sigma_hat`
            # The model inputs and output are adjusted by following eq. (213) in [1].
            half_sigma_hat = sigma_hat / 2
            model_output = half_sigma_hat * denoiser((noised_sample + 1) / 2, half_sigma_hat).sample

            # 4. Evaluate dx/dt at sigma_hat
            # 5. Take Euler step from sigma to sigma_prev
            step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, noised_sample)

            if sigma_prev != 0:
                # 6. Apply 2nd order correction
                # The model inputs and output are adjusted by following eq. (213) in [1].
                half_sigma_prev = sigma_prev / 2
                euler_sample = step_output.prev_sample
                model_output = half_sigma_prev * denoiser((euler_sample + 1) / 2, half_sigma_prev).sample
                step_output = self.scheduler.step_correct(
                    model_output,
                    sigma_hat,
                    sigma_prev,
                    noised_sample,
                    euler_sample,
                    step_output["derivative"],
                )
            sample = step_output.prev_sample

        # Map to [0, 1] and move channels last before converting to numpy.
        sample = (sample / 2 + 0.5).clamp(0, 1)
        image = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if return_dict:
            return ImagePipelineOutput(images=image)

        return (image,)
["VersatileDiffusionImageVariationPipeline"] _import_structure["pipeline_versatile_diffusion_text_to_image"] = ["VersatileDiffusionTextToImagePipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ....utils.dummy_torch_and_transformers_objects import ( VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, ) else: from .pipeline_versatile_diffusion import VersatileDiffusionPipeline from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py ================================================ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from diffusers.utils import deprecate from ....configuration_utils import ConfigMixin, register_to_config from ....models import ModelMixin from ....models.activations import get_activation from ....models.attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention, AttentionProcessor, AttnAddedKVProcessor, AttnAddedKVProcessor2_0, AttnProcessor, ) from ....models.embeddings import ( GaussianFourierProjection, ImageHintTimeEmbedding, ImageProjection, 
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    num_attention_heads,
    transformer_layers_per_block,
    attention_type,
    attention_head_dim,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    dropout=0.0,
):
    """Build the down block named by `down_block_type`.

    A leading `"UNetRes"` prefix on the name is ignored. Supported names are
    `"DownBlockFlat"` and `"CrossAttnDownBlockFlat"`. Several parameters are
    accepted but not forwarded to either block (for example
    `transformer_layers_per_block`, `attention_type`, `attention_head_dim`,
    `upcast_attention`, `resnet_skip_time_act`, `resnet_out_scale_factor`
    and `cross_attention_norm`).

    Raises:
        ValueError: If the block name is not supported, or if
            `"CrossAttnDownBlockFlat"` is requested without `cross_attention_dim`.
    """
    legacy_prefix = "UNetRes"
    if down_block_type.startswith(legacy_prefix):
        down_block_type = down_block_type[len(legacy_prefix) :]

    if down_block_type not in ("DownBlockFlat", "CrossAttnDownBlockFlat"):
        raise ValueError(f"{down_block_type} is not supported.")

    # Keyword arguments shared by both supported block types.
    shared_kwargs = {
        "num_layers": num_layers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "temb_channels": temb_channels,
        "dropout": dropout,
        "add_downsample": add_downsample,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "resnet_groups": resnet_groups,
        "downsample_padding": downsample_padding,
        "resnet_time_scale_shift": resnet_time_scale_shift,
    }

    if down_block_type == "DownBlockFlat":
        return DownBlockFlat(**shared_kwargs)

    # Only the cross-attention variant remains at this point.
    if cross_attention_dim is None:
        raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat")
    return CrossAttnDownBlockFlat(
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=num_attention_heads,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention,
        **shared_kwargs,
    )
class FourierEmbedder(nn.Module):
    """Sin/cos Fourier feature embedding over `num_freqs` geometrically spaced frequencies."""

    def __init__(self, num_freqs=64, temperature=100):
        super().__init__()

        self.num_freqs = num_freqs
        self.temperature = temperature

        # Frequencies temperature ** (k / num_freqs) for k = 0 .. num_freqs - 1.
        freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
        freq_bands = freq_bands[None, None, None]
        self.register_buffer("freq_bands", freq_bands, persistent=False)

    # NOTE: `__call__` is defined directly (not `forward`), so calling the module
    # does not go through `nn.Module.__call__`.
    def __call__(self, x):
        """Embed `x`; the 5-index permute below requires `x` to be rank 3.

        The last dimension of the result has size `x.shape[-1] * num_freqs * 2`.
        """
        x = self.freq_bands * x.unsqueeze(-1)
        return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)


class GLIGENTextBoundingboxProjection(nn.Module):
    """Project box coordinates together with text (and optionally image) embeddings.

    `feature_type="text-only"` builds a single MLP (`linears`);
    `feature_type="text-image"` builds one MLP per modality (`linears_text`,
    `linears_image`). Masked-out entries are replaced by learned null features.
    """

    def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy

        # Only the first entry is used as the MLP output width when a tuple is given.
        if isinstance(out_dim, tuple):
            out_dim = out_dim[0]

        if feature_type == "text-only":
            self.linears = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        elif feature_type == "text-image":
            self.linears_text = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.linears_image = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

    def forward(
        self,
        boxes,
        masks,
        positive_embeddings=None,
        phrases_masks=None,
        image_masks=None,
        phrases_embeddings=None,
        image_embeddings=None,
    ):
        """Return the projected object features.

        When `positive_embeddings` is given, the text-only MLP is used. Otherwise
        the text and image MLPs are applied to `phrases_embeddings` and
        `image_embeddings` and their outputs are concatenated along dim 1.
        Shapes are presumably `(batch, num_boxes, ...)` with 4 coordinates per
        box — confirm against callers.
        """
        masks = masks.unsqueeze(-1)

        # Fourier-embed the box coordinates and swap in the learned null where masked out.
        xyxy_embedding = self.fourier_embedder(boxes)
        xyxy_null = self.null_position_feature.view(1, 1, -1)
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        # Compare against None explicitly: truth-testing a tensor with more than
        # one element raises a RuntimeError.
        if positive_embeddings is not None:
            positive_null = self.null_positive_feature.view(1, 1, -1)
            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null

            objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
        else:
            phrases_masks = phrases_masks.unsqueeze(-1)
            image_masks = image_masks.unsqueeze(-1)

            text_null = self.null_text_feature.view(1, 1, -1)
            image_null = self.null_image_feature.view(1, 1, -1)

            phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
            image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null

            objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
            objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
            objs = torch.cat([objs_text, objs_image], dim=1)

        return objs
flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): Block type for middle of UNet, it can be one of `UNetMidBlockFlatCrossAttn`, `UNetMidBlockFlat`, or `UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`): The tuple of upsample blocks to use. only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): Whether to include self-attention in the basic transformer blocks, see [`~models.attention.BasicTransformerBlock`]. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. 
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim` resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlockFlat`]). Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. 
Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. addition_time_embed_dim: (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings. num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. 
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` otherwise. """ _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"] @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat", ), mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn", up_block_types: Tuple[str] = ( "UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, dropout: float = 0.0, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: 
Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: int = 1.0, time_embedding_type: str = "positional", time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, attention_type: str = "default", class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads=64, ): super().__init__() self.sample_size = sample_size if num_attention_heads is not None: raise ValueError( "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." ) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 
) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
) if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: for layer_number_per_block in transformer_layers_per_block: if isinstance(layer_number_per_block, list): raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = LinearMultiDim( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." ) self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, ) if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." 
) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 self.encoder_hid_proj = ImageProjection( image_embed_dim=encoder_hid_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type is not None: raise ValueError( f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." ) else: self.encoder_hid_proj = None # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. 
# As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None if addition_embed_type == "text": if encoder_hid_dim is not None: text_time_embedding_from_dim = encoder_hid_dim else: text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type == "image": # Kandinsky 2.2 self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") if 
time_embedding_act_fn is None: self.time_embed_act = None else: self.time_embed_act = get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): if mid_block_only_cross_attention is None: mid_block_only_cross_attention = only_cross_attention only_cross_attention = [only_cross_attention] * len(down_block_types) if mid_block_only_cross_attention is None: mid_block_only_cross_attention = False if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. 
The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the # regular time embeddings blocks_time_embed_dim = time_embed_dim * 2 else: blocks_time_embed_dim = time_embed_dim # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, dropout=dropout, ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlockFlatCrossAttn": self.mid_block = UNetMidBlockFlatCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, dropout=dropout, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, 
dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, ) elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn": self.mid_block = UNetMidBlockFlatSimpleCrossAttn( in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, dropout=dropout, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, cross_attention_dim=cross_attention_dim[-1], attention_head_dim=attention_head_dim[-1], resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif mid_block_type == "UNetMidBlockFlat": self.mid_block = UNetMidBlockFlat( in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, dropout=dropout, num_layers=0, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, add_attention=False, ) elif mid_block_type is None: self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) reversed_transformer_layers_per_block = ( list(reversed(transformer_layers_per_block)) if reverse_transformer_layers_per_block is None else reverse_transformer_layers_per_block ) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel 
output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resolution_idx=i, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, dropout=dropout, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out if norm_num_groups is not None: self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) self.conv_act = get_activation(act_fn) else: self.conv_norm_out = None self.conv_act = None conv_out_padding = (conv_out_kernel - 1) // 2 self.conv_out = LinearMultiDim( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) if attention_type in ["gated", "gated-text-image"]: positive_len = 768 if isinstance(cross_attention_dim, int): positive_len = cross_attention_dim elif isinstance(cross_attention_dim, (list, tuple)): positive_len = 
cross_attention_dim[0] feature_type = "text-only" if attention_type == "gated" else "text-image" self.position_net = GLIGENTextBoundingboxProjection( positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type ) @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnAddedKVProcessor() elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnProcessor() else: raise ValueError( f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) self.set_attn_processor(processor) def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. 
""" sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. 
# Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def enable_freeu(self, s1, s2, b1, b2): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. The suffixes after the scaling factors represent the stage blocks where they are being applied. Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. Args: s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process. s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process. b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    Restores the attention processors stored in `self.original_attn_processors`
    (which `fuse_qkv_projections()` saves before fusing) by passing them to
    `self.set_attn_processor`. When nothing was saved -- the attribute is missing
    or `None` -- this method does nothing.

    This API is 🧪 experimental.
    """
    # `original_attn_processors` is only assigned inside `fuse_qkv_projections()`,
    # so look it up defensively: calling this method first must not raise.
    saved_processors = getattr(self, "original_attn_processors", None)
    if saved_processors is not None:
        self.set_attn_processor(saved_processors)
Please install `peft` and then call `disable_adapters().", ) for module in self.modules(): if hasattr(module, "set_lora_layer"): module.set_lora_layer(None) def forward( self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" The [`UNetFlatConditionModel`] forward method. Args: sample (`torch.Tensor`): The noisy input tensor with the following shape `(batch, channel, height, width)`. timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.Tensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. class_labels (`torch.Tensor`, *optional*, defaults to `None`): Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed through the `self.time_embedding` layer to obtain the timestep embeddings. attention_mask (`torch.Tensor`, *optional*, defaults to `None`): An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): A tuple of tensors that if specified are added to the residuals of down unet blocks. mid_block_additional_residual: (`torch.Tensor`, *optional*): A tensor that if specified is added to the residual of the middle unet block. encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. 
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): additional residuals to be added to UNet long skip connections from down blocks to up blocks for example from ControlNet side model(s) mid_block_additional_residual (`torch.Tensor`, *optional*): additional residual to be added to UNet mid block output, for example from ControlNet side model down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) Returns: [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None for dim in sample.shape[-2:]: if dim % default_overall_up_factor != 0: # Forward upsample size to force interpolation output size. forward_upsample_size = True break # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. 
t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = 
time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # 
Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") image_embeds = self.encoder_hid_proj(image_embeds) encoder_hidden_states = (encoder_hidden_states, image_embeds) # 2. pre-process sample = self.conv_in(sample) # 2.5 GLIGEN position net if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: cross_attention_kwargs = cross_attention_kwargs.copy() gligen_args = cross_attention_kwargs.pop("gligen") cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. 
down lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets is_adapter = down_intrablock_additional_residuals is not None # maintain backward compatibility for legacy usage, where # T2I-Adapter and ControlNet both use down_block_additional_residuals arg # but can only use one or the other if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: deprecate( "T2I should not use down_block_additional_residuals", "1.3.0", "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. 
", standard_warn=False, ) down_intrablock_additional_residuals = down_block_additional_residuals is_adapter = True down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlockFlat additional_residuals = {} if is_adapter and len(down_intrablock_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_intrablock_additional_residuals) > 0: sample += down_intrablock_additional_residuals.pop(0) down_block_res_samples += res_samples if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. 
mid if self.mid_block is not None: if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) else: sample = self.mid_block(sample, emb) # To support T2I-Adapter-XL if ( is_adapter and len(down_intrablock_additional_residuals) > 0 and sample.shape == down_intrablock_additional_residuals[0].shape ): sample += down_intrablock_additional_residuals.pop(0) if is_controlnet: sample = sample + mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, scale=lora_scale, ) # 6. 
class ResnetBlockFlat(nn.Module):
    """Residual block operating on tensors whose trailing dims are flattened into channels.

    The trailing `len(in_channels)` dimensions of the input are collapsed into a single
    channel axis of size `prod(in_channels)`, processed with GroupNorm / SiLU / 1x1 conv
    layers (plus an optional projected time embedding), and finally reshaped to
    `out_channels`.

    Args:
        in_channels: An int (expanded to `[in_channels, second_dim, 1]`) or an iterable
            giving the trailing input dimensions.
        out_channels: Same convention as `in_channels`; `None` keeps the input layout.
        dropout: Dropout probability applied before the second convolution.
        temb_channels: Width of the time embedding; `None` disables the projection.
        groups: Number of groups for the first GroupNorm.
        groups_out: Number of groups for the second GroupNorm; defaults to `groups`.
        pre_norm: Accepted for interface compatibility; `self.pre_norm` is always `True`.
        eps: Epsilon for both GroupNorm layers.
        time_embedding_norm: Stored on the instance as `self.time_embedding_norm`.
        use_in_shortcut: Force (or forbid) the 1x1 shortcut convolution. When `None` the
            shortcut is used only if the flattened input and output sizes differ.
        second_dim: Second dimension used when `in_channels`/`out_channels` are ints.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        time_embedding_norm="default",
        use_in_shortcut=None,
        second_dim=4,
        **kwargs,
    ):
        super().__init__()
        # The block always runs in pre-norm mode, regardless of the `pre_norm` argument.
        self.pre_norm = True

        in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels)
        # Store plain Python ints (not numpy scalars) so downstream shape arithmetic and
        # comparisons yield builtin types.
        self.in_channels_prod = int(np.array(in_channels).prod())
        self.channels_multidim = in_channels

        if out_channels is not None:
            out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels)
            out_channels_prod = int(np.array(out_channels).prod())
            self.out_channels_multidim = out_channels
        else:
            out_channels_prod = self.in_channels_prod
            self.out_channels_multidim = self.channels_multidim
        self.time_embedding_norm = time_embedding_norm

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True)
        self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0)

        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0)

        self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = (
            self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
        )

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0
            )

    def forward(self, input_tensor, temb):
        """Apply the residual block.

        Args:
            input_tensor: Tensor whose trailing `len(self.channels_multidim)` dimensions
                multiply to `self.in_channels_prod`.
            temb: Optional time embedding. It is projected and added after the first
                convolution; it is skipped when `None` or when the block was built
                without a time projection (`temb_channels=None`).

        Returns:
            Tensor with the same leading dimensions as `input_tensor` and trailing
            dimensions `self.out_channels_multidim`.
        """
        shape = input_tensor.shape
        n_dim = len(self.channels_multidim)
        leading_shape = shape[0:-n_dim]

        # First collapse the trailing dims into one channel axis (this validates the
        # input layout), then fold all leading dims into a single batch axis.
        input_tensor = input_tensor.reshape(*leading_shape, self.in_channels_prod, 1, 1)
        input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1)

        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        # Guard on the projection as well: with `temb_channels=None` there is nothing to
        # project with, and calling `None` would raise a TypeError.
        if temb is not None and self.time_emb_proj is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        # Restore the leading dims and expand the channel axis to the output layout.
        output_tensor = output_tensor.view(*leading_shape, *self.out_channels_multidim)

        return output_tensor
class CrossAttnDownBlockFlat(nn.Module):
    """Down block pairing each `ResnetBlockFlat` with a cross-attention transformer.

    Every one of the `num_layers` stages is a resnet followed by a `Transformer2DModel`
    (or a `DualTransformer2DModel` when `dual_cross_attention` is set). When
    `add_downsample` is true, one `LinearMultiDim` layer is registered as the downsampler.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        add_downsample: bool = True,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # A single int means "use this transformer depth for every layer".
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        resnet_layers = []
        attention_layers = []
        for layer_idx in range(num_layers):
            # Only the first resnet sees the block's input width.
            layer_in_channels = out_channels if layer_idx != 0 else in_channels
            resnet_layers.append(
                ResnetBlockFlat(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

            if dual_cross_attention:
                attention_layer = DualTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            else:
                attention_layer = Transformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[layer_idx],
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    attention_type=attention_type,
                )
            attention_layers.append(attention_layer)

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.downsamplers = None
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    LinearMultiDim(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        additional_residuals: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Run every resnet/attention pair, then the optional downsampler.

        Returns:
            A pair `(hidden_states, output_states)`, where `output_states` holds the
            hidden states produced after each resnet/attention pair and, if a
            downsampler exists, the downsampled hidden states as the final entry.
        """

        def bind_module(module):
            # Bind `module` eagerly so the wrapper still refers to the right layer if it
            # is invoked after the loop variable has moved on.
            def run(*inputs):
                return module(*inputs)

            return run

        collected_states = []
        layer_pairs = list(zip(self.resnets, self.attentions))
        last_idx = len(layer_pairs) - 1

        for idx, (resnet, attn) in enumerate(layer_pairs):
            if self.training and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    bind_module(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            # The attention call is the same whether or not the resnet was checkpointed.
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

            # apply additional residuals to the output of the last pair of resnet and attention blocks
            if additional_residuals is not None and idx == last_idx:
                hidden_states = hidden_states + additional_residuals

            collected_states.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            collected_states.append(hidden_states)

        return hidden_states, tuple(collected_states)
class CrossAttnUpBlockFlat(nn.Module):
    """
    Up-sampling UNet block in which every residual layer is followed by a cross-attention transformer.

    For each of the `num_layers` stages, `forward` concatenates the running features with one skip tensor taken from
    the end of `res_hidden_states_tuple`, applies a `ResnetBlockFlat`, then a `Transformer2DModel` (or a
    `DualTransformer2DModel` when `dual_cross_attention=True`). When `add_upsample=True` a `LinearMultiDim` layer is
    applied after the last stage.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        # A bare int means "the same transformer depth for every layer".
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        final_layer = num_layers - 1
        resnet_stack = []
        attention_stack = []

        for layer_idx in range(num_layers):
            # The first layer consumes the previous block's output; the last layer's skip
            # connection carries `in_channels`, every other one carries `out_channels`.
            main_channels = out_channels if layer_idx else prev_output_channel
            skip_channels = in_channels if layer_idx == final_layer else out_channels

            # The resnet is built before its attention layer so parameter initialisation
            # happens in the same order as the modules are applied.
            resnet_stack.append(
                ResnetBlockFlat(
                    in_channels=main_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

            if dual_cross_attention:
                attention_layer = DualTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            else:
                attention_layer = Transformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[layer_idx],
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    attention_type=attention_type,
                )
            attention_stack.append(attention_layer)

        self.attentions = nn.ModuleList(attention_stack)
        self.resnets = nn.ModuleList(resnet_stack)

        self.upsamplers = (
            nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the resnet/attention pairs over `hidden_states`, then the optional upsamplers.

        Skip tensors are consumed from the end of `res_hidden_states_tuple`, one per layer, and concatenated to the
        running features along dim 1 before each resnet. Returns the resulting tensor.
        """
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # FreeU is only applied when all four of its factors are set (and truthy) on the module.
        freeu_active = all(getattr(self, factor, None) for factor in ("s1", "s2", "b1", "b2"))

        def as_checkpoint_fn(module):
            def call_module(*inputs):
                return module(*inputs)

            return call_module

        for depth, (resnet, attn) in enumerate(zip(self.resnets, self.attentions), start=1):
            # Layer `depth` uses the `depth`-th skip tensor counted from the end.
            skip_states = res_hidden_states_tuple[-depth]

            if freeu_active:
                hidden_states, skip_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    skip_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )

            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            if self.training and self.gradient_checkpointing:
                # Only the resnet is checkpointed; the attention call below runs normally.
                checkpoint_options: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    as_checkpoint_fn(resnet),
                    hidden_states,
                    temb,
                    **checkpoint_options,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
class UNetMidBlockFlat(nn.Module):
    """
    A UNet mid-block [`UNetMidBlockFlat`] made of residual blocks with optional attention layers in between.

    The block always starts with one residual block; each of the `num_layers` layers then adds one (optional)
    attention layer followed by another residual block.

    Args:
        in_channels (`int`): The number of input channels.
        temb_channels (`int`): The number of temporal embedding channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of attention/residual pairs after the first resnet.
        resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
        resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
            How the time embedding is applied in the resnet blocks. `"spatial"` builds `ResnetBlockCondNorm2D`
            blocks; any other value is forwarded to `ResnetBlockFlat` as `time_embedding_norm`.
        resnet_act_fn (`str`, *optional*, defaults to `swish`):
            The activation function for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks. `None` falls back to
            `min(in_channels // 4, 32)`.
        attn_groups (`Optional[int]`, *optional*, defaults to None):
            The number of groups for the attention blocks. When `None`, `resnet_groups` is used if
            `resnet_time_scale_shift == "default"`, otherwise it stays `None`.
        resnet_pre_norm (`bool`, *optional*, defaults to `True`):
            Whether to use pre-normalization for the resnet blocks.
        add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
        attention_head_dim (`int`, *optional*, defaults to 1):
            Dimension of a single attention head. The number of attention heads is `in_channels //
            attention_head_dim`. `None` falls back to `in_channels` (with a warning).
        output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

    Returns:
        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
        height, width)`.

    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        attn_groups: Optional[int] = None,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()

        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)
        self.add_attention = add_attention

        if attn_groups is None and resnet_time_scale_shift == "default":
            attn_groups = resnet_groups

        spatial_norm = resnet_time_scale_shift == "spatial"

        def build_resnet():
            # Both resnet flavours share these arguments; only the conditioning differs.
            common = {
                "in_channels": in_channels,
                "out_channels": in_channels,
                "temb_channels": temb_channels,
                "eps": resnet_eps,
                "groups": resnet_groups,
                "dropout": dropout,
                "non_linearity": resnet_act_fn,
                "output_scale_factor": output_scale_factor,
            }
            if spatial_norm:
                return ResnetBlockCondNorm2D(time_embedding_norm="spatial", **common)
            return ResnetBlockFlat(
                time_embedding_norm=resnet_time_scale_shift,
                pre_norm=resnet_pre_norm,
                **common,
            )

        # there is always at least one resnet
        resnet_stack = [build_resnet()]

        if attention_head_dim is None:
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
            )
            attention_head_dim = in_channels

        def build_attention():
            return Attention(
                in_channels,
                heads=in_channels // attention_head_dim,
                dim_head=attention_head_dim,
                rescale_output_factor=output_scale_factor,
                eps=resnet_eps,
                norm_num_groups=attn_groups,
                spatial_norm_dim=temb_channels if spatial_norm else None,
                residual_connection=True,
                bias=True,
                upcast_softmax=True,
                _from_deprecated_attn_block=True,
            )

        attention_stack = []
        for _ in range(num_layers):
            # A `None` placeholder keeps attentions and resnets aligned when attention is disabled.
            attention_stack.append(build_attention() if self.add_attention else None)
            resnet_stack.append(build_resnet())

        self.attentions = nn.ModuleList(attention_stack)
        self.resnets = nn.ModuleList(resnet_stack)

    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the first resnet, then each (optional attention, resnet) pair in order."""
        entry_resnet = self.resnets[0]
        hidden_states = entry_resnet(hidden_states, temb)

        for attention, resnet in zip(self.attentions, self.resnets[1:]):
            if attention is not None:
                hidden_states = attention(hidden_states, temb=temb)
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
class UNetMidBlockFlatCrossAttn(nn.Module):
    """
    UNet mid-block that alternates cross-attention transformers with `ResnetBlockFlat` layers.

    The block starts with one resnet mapping `in_channels` to `out_channels`; each of the `num_layers` layers then
    adds a `Transformer2DModel` (or `DualTransformer2DModel` when `dual_cross_attention=True`) followed by another
    resnet operating on `out_channels`.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_groups_out: Optional[int] = None,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        attention_type: str = "default",
    ):
        super().__init__()

        # A missing (falsy) `out_channels` means the block keeps the channel count.
        if not out_channels:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        # support for variable transformer layers per block
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        if not resnet_groups_out:
            resnet_groups_out = resnet_groups

        # Arguments shared by every resnet in the block.
        resnet_common = {
            "out_channels": out_channels,
            "temb_channels": temb_channels,
            "eps": resnet_eps,
            "dropout": dropout,
            "time_embedding_norm": resnet_time_scale_shift,
            "non_linearity": resnet_act_fn,
            "output_scale_factor": output_scale_factor,
            "pre_norm": resnet_pre_norm,
        }

        # there is always at least one resnet
        resnet_stack = [
            ResnetBlockFlat(
                in_channels=in_channels,
                groups=resnet_groups,
                groups_out=resnet_groups_out,
                **resnet_common,
            )
        ]
        attention_stack = []

        for layer_idx in range(num_layers):
            if dual_cross_attention:
                # NOTE(review): this branch normalizes with `resnet_groups` while the single-transformer
                # branch uses `resnet_groups_out`, although both attend over `out_channels` — confirm the
                # difference is intended.
                attention_layer = DualTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            else:
                attention_layer = Transformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[layer_idx],
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups_out,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    attention_type=attention_type,
                )
            attention_stack.append(attention_layer)

            resnet_stack.append(
                ResnetBlockFlat(
                    in_channels=out_channels,
                    groups=resnet_groups_out,
                    **resnet_common,
                )
            )

        self.attentions = nn.ModuleList(attention_stack)
        self.resnets = nn.ModuleList(resnet_stack)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply the first resnet, then each (attention, resnet) pair in order, and return the result.

        When the module is training with `gradient_checkpointing` enabled, the resnet of each pair is run through
        `torch.utils.checkpoint.checkpoint`; the attention call is never checkpointed here.
        """
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        def as_checkpoint_fn(module):
            def call_module(*inputs):
                return module(*inputs)

            return call_module

        hidden_states = self.resnets[0](hidden_states, temb)

        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

            if self.training and self.gradient_checkpointing:
                checkpoint_options: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    as_checkpoint_fn(resnet),
                    hidden_states,
                    temb,
                    **checkpoint_options,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

        return hidden_states
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
    """
    UNet mid-block that alternates added-KV `Attention` layers with `ResnetBlockFlat` layers.

    The block starts with one resnet; each of the `num_layers` layers then adds one `Attention` layer (configured
    with `added_kv_proj_dim=cross_attention_dim`) followed by another resnet. All layers keep `in_channels`.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        skip_time_act: bool = False,
        only_cross_attention: bool = False,
        cross_attention_norm: Optional[str] = None,
    ):
        super().__init__()

        self.has_cross_attention = True

        self.attention_head_dim = attention_head_dim
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        self.num_heads = in_channels // self.attention_head_dim

        def build_resnet():
            return ResnetBlockFlat(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                skip_time_act=skip_time_act,
            )

        # there is always at least one resnet
        resnet_stack = [build_resnet()]
        attention_stack = []

        for _ in range(num_layers):
            # Each attention layer gets its own processor instance; the 2.0 variant is chosen
            # whenever `F.scaled_dot_product_attention` is available.
            if hasattr(F, "scaled_dot_product_attention"):
                processor = AttnAddedKVProcessor2_0()
            else:
                processor = AttnAddedKVProcessor()

            attention_stack.append(
                Attention(
                    query_dim=in_channels,
                    cross_attention_dim=in_channels,
                    heads=self.num_heads,
                    dim_head=self.attention_head_dim,
                    added_kv_proj_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    bias=True,
                    upcast_softmax=True,
                    only_cross_attention=only_cross_attention,
                    cross_attention_norm=cross_attention_norm,
                    processor=processor,
                )
            )
            resnet_stack.append(build_resnet())

        self.attentions = nn.ModuleList(attention_stack)
        self.resnets = nn.ModuleList(resnet_stack)

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply the first resnet, then each (attention, resnet) pair in order, and return the result.

        The mask handed to the attention layers is `attention_mask` when given; otherwise it is
        `encoder_attention_mask` if `encoder_hidden_states` is provided, else `None`. Entries of
        `cross_attention_kwargs` are forwarded to every attention call as keyword arguments.
        """
        extra_attn_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs
        if extra_attn_kwargs.get("scale", None) is not None:
            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        if attention_mask is not None:
            # when attention_mask is defined: we don't even check for encoder_attention_mask.
            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
            #       then we can simplify this whole if/else block to:
            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
            mask = attention_mask
        elif encoder_hidden_states is None:
            mask = None
        else:
            # encoder_hidden_states is defined: we are doing cross-attn, so we should use the cross-attn mask.
            mask = encoder_attention_mask

        hidden_states = self.resnets[0](hidden_states, temb)
        for attention, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attention(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=mask,
                **extra_attn_kwargs,
            )
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
class VersatileDiffusionPipeline(DiffusionPipeline):
    r"""
    Combined Versatile Diffusion pipeline exposing three tasks on one shared set of components: image variation
    ([`~VersatileDiffusionPipeline.image_variation`]), text-to-image
    ([`~VersatileDiffusionPipeline.text_to_image`]) and dual text/image guidance
    ([`~VersatileDiffusionPipeline.dual_guided`]).

    Each task method instantiates the matching task pipeline from `self.components` and forwards the call to it.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        image_feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` (presumably used to preprocess image prompts for `image_encoder` — confirm against
            the task pipelines).
        text_encoder ([`~transformers.CLIPTextModel`]):
            The CLIP text encoder.
        image_encoder ([`~transformers.CLIPVisionModel`]):
            The CLIP vision encoder.
        image_unet ([`UNet2DConditionModel`]):
            The UNet registered as `image_unet`. The text-to-image and dual-guided task pipelines temporarily modify
            its attention blocks and restore them when the call finishes.
        text_unet ([`UNet2DConditionModel`]):
            The UNet registered as `text_unet`.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        scheduler ([`SchedulerMixin`]):
            A scheduler (one of the [`KarrasDiffusionSchedulers`]) handed to the task pipelines.
    """

    tokenizer: CLIPTokenizer
    image_feature_extractor: CLIPImageProcessor
    text_encoder: CLIPTextModel
    image_encoder: CLIPVisionModel
    image_unet: UNet2DConditionModel
    text_unet: UNet2DConditionModel
    vae: AutoencoderKL
    scheduler: KarrasDiffusionSchedulers

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        image_feature_extractor: CLIPImageProcessor,
        text_encoder: CLIPTextModel,
        image_encoder: CLIPVisionModel,
        image_unet: UNet2DConditionModel,
        text_unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: KarrasDiffusionSchedulers,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer,
            image_feature_extractor=image_feature_extractor,
            text_encoder=text_encoder,
            image_encoder=image_encoder,
            image_unet=image_unet,
            text_unet=text_unet,
            vae=vae,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def _build_task_pipeline(self, pipeline_cls):
        """Instantiate `pipeline_cls` with the subset of `self.components` that its `__init__` accepts."""
        accepted_names = inspect.signature(pipeline_cls.__init__).parameters.keys()
        components = {name: component for name, component in self.components.items() if name in accepted_names}
        return pipeline_cls(**components)

    @torch.no_grad()
    def image_variation(
        self,
        image: Union[torch.Tensor, PIL.Image.Image],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Generate variations of an image by delegating to [`VersatileDiffusionImageVariationPipeline`].

        Args:
            image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
                The image prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the
                conditioning input at the expense of lower image quality. Guidance scale is enabled when
                `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. Ignored when guidance is not
                enabled.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return the delegated pipeline's output object instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function called every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Examples:

        ```py
        >>> from diffusers import VersatileDiffusionPipeline
        >>> import torch
        >>> import requests
        >>> from io import BytesIO
        >>> from PIL import Image

        >>> # let's download an initial image
        >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"

        >>> response = requests.get(url)
        >>> image = Image.open(BytesIO(response.content)).convert("RGB")

        >>> pipe = VersatileDiffusionPipeline.from_pretrained(
        ...     "shi-labs/versatile-diffusion", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> generator = torch.Generator(device="cuda").manual_seed(0)
        >>> image = pipe.image_variation(image, generator=generator).images[0]
        >>> image.save("./car_variation.png")
        ```

        Returns:
            The value returned by the delegated [`VersatileDiffusionImageVariationPipeline`] call, unchanged: an
            output object exposing `.images` when `return_dict` is `True`, otherwise a `tuple` whose first element
            is the list of generated images.
            NOTE(review): earlier docs described a `StableDiffusionPipelineOutput` with NSFW flags, but this
            pipeline registers no safety checker — confirm the exact output type against the task pipeline.
        """
        task_pipeline = self._build_task_pipeline(VersatileDiffusionImageVariationPipeline)
        return task_pipeline(
            image=image,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

    @torch.no_grad()
    def text_to_image(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Generate images from text by delegating to [`VersatileDiffusionTextToImagePipeline`].

        The temporary task pipeline works on the UNets shared with this pipeline; its attention blocks are swapped
        back to their original state once the call finishes, whether it succeeds or raises.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide image generation.
            height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. Ignored when guidance is not
                enabled.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return the delegated pipeline's output object instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function called every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Examples:

        ```py
        >>> from diffusers import VersatileDiffusionPipeline
        >>> import torch

        >>> pipe = VersatileDiffusionPipeline.from_pretrained(
        ...     "shi-labs/versatile-diffusion", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> generator = torch.Generator(device="cuda").manual_seed(0)
        >>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0]
        >>> image.save("./astronaut.png")
        ```

        Returns:
            The value returned by the delegated [`VersatileDiffusionTextToImagePipeline`] call, unchanged: an output
            object exposing `.images` when `return_dict` is `True`, otherwise a `tuple` whose first element is the
            list of generated images.
            NOTE(review): earlier docs described a `StableDiffusionPipelineOutput` with NSFW flags, but this
            pipeline registers no safety checker — confirm the exact output type against the task pipeline.
        """
        task_pipeline = self._build_task_pipeline(VersatileDiffusionTextToImagePipeline)
        try:
            output = task_pipeline(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt,
                num_images_per_prompt=num_images_per_prompt,
                eta=eta,
                generator=generator,
                latents=latents,
                output_type=output_type,
                return_dict=return_dict,
                callback=callback,
                callback_steps=callback_steps,
            )
        finally:
            # swap the attention blocks back to the original state, even if generation failed, so the
            # UNets shared with this pipeline are never left modified for later calls
            task_pipeline._swap_unet_attention_blocks()

        return output

    @torch.no_grad()
    def dual_guided(
        self,
        prompt: Union[str, List[str]],
        image: Union[PIL.Image.Image, List[PIL.Image.Image]],
        text_to_image_strength: float = 0.5,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Generate images guided by both a text prompt and an image by delegating to
        [`VersatileDiffusionDualGuidedPipeline`].

        The temporary task pipeline may convert the shared `image_unet` to dual attention; that conversion is
        reverted once the call finishes, whether it succeeds or raises.

        Args:
            prompt (`str` or `List[str]`):
                The text prompt or prompts to guide image generation.
            image (`PIL.Image.Image` or `List[PIL.Image.Image]`):
                The image prompt or prompts to guide image generation.
            text_to_image_strength (`float`, *optional*, defaults to 0.5):
                Forwarded unchanged to the dual-guided pipeline (presumably the relative weight of the text
                conditioning versus the image conditioning — confirm against
                [`VersatileDiffusionDualGuidedPipeline`]).
            height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the
                conditioning inputs at the expense of lower image quality. Guidance scale is enabled when
                `guidance_scale > 1`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return the delegated pipeline's output object instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function called every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Examples:

        ```py
        >>> from diffusers import VersatileDiffusionPipeline
        >>> import torch
        >>> import requests
        >>> from io import BytesIO
        >>> from PIL import Image

        >>> # let's download an initial image
        >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"

        >>> response = requests.get(url)
        >>> image = Image.open(BytesIO(response.content)).convert("RGB")
        >>> text = "a red car in the sun"

        >>> pipe = VersatileDiffusionPipeline.from_pretrained(
        ...     "shi-labs/versatile-diffusion", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> generator = torch.Generator(device="cuda").manual_seed(0)
        >>> text_to_image_strength = 0.75

        >>> image = pipe.dual_guided(
        ...     prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator
        ... ).images[0]
        >>> image.save("./car_variation.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """
        task_pipeline = self._build_task_pipeline(VersatileDiffusionDualGuidedPipeline)
        try:
            output = task_pipeline(
                prompt=prompt,
                image=image,
                text_to_image_strength=text_to_image_strength,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_images_per_prompt,
                eta=eta,
                generator=generator,
                latents=latents,
                output_type=output_type,
                return_dict=return_dict,
                callback=callback,
                callback_steps=callback_steps,
            )
        finally:
            # Undo the dual-attention conversion of the shared `image_unet` even when generation
            # failed, so later `image_variation` / `text_to_image` calls see the original blocks.
            task_pipeline._revert_dual_attention()

        return output
import inspect from typing import Callable, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.utils.checkpoint from transformers import ( CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ....image_processor import VaeImageProcessor from ....models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel from ....schedulers import KarrasDiffusionSchedulers from ....utils import deprecate, logging from ....utils.torch_utils import randn_tensor from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_text_unet import UNetFlatConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): r""" Pipeline for image-text dual-guided generation using Versatile Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Parameters: vqvae ([`VQModel`]): Vector-quantized (VQ) model to encode and decode images to and from latent representations. bert ([`LDMBertModel`]): Text-encoder model based on [`~transformers.BERT`]. tokenizer ([`~transformers.BertTokenizer`]): A `BertTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
""" model_cpu_offload_seq = "bert->unet->vqvae" tokenizer: CLIPTokenizer image_feature_extractor: CLIPImageProcessor text_encoder: CLIPTextModelWithProjection image_encoder: CLIPVisionModelWithProjection image_unet: UNet2DConditionModel text_unet: UNetFlatConditionModel vae: AutoencoderKL scheduler: KarrasDiffusionSchedulers _optional_components = ["text_unet"] def __init__( self, tokenizer: CLIPTokenizer, image_feature_extractor: CLIPImageProcessor, text_encoder: CLIPTextModelWithProjection, image_encoder: CLIPVisionModelWithProjection, image_unet: UNet2DConditionModel, text_unet: UNetFlatConditionModel, vae: AutoencoderKL, scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( tokenizer=tokenizer, image_feature_extractor=image_feature_extractor, text_encoder=text_encoder, image_encoder=image_encoder, image_unet=image_unet, text_unet=text_unet, vae=vae, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if self.text_unet is not None and ( "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention ): # if loading from a universal checkpoint rather than a saved dual-guided pipeline self._convert_to_dual_attention() def remove_unused_weights(self): self.register_modules(text_unet=None) def _convert_to_dual_attention(self): """ Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks from both `image_unet` and `text_unet` """ for name, module in self.image_unet.named_modules(): if isinstance(module, Transformer2DModel): parent_name, index = name.rsplit(".", 1) index = int(index) image_transformer = self.image_unet.get_submodule(parent_name)[index] text_transformer = self.text_unet.get_submodule(parent_name)[index] config = image_transformer.config dual_transformer = DualTransformer2DModel( 
num_attention_heads=config.num_attention_heads, attention_head_dim=config.attention_head_dim, in_channels=config.in_channels, num_layers=config.num_layers, dropout=config.dropout, norm_num_groups=config.norm_num_groups, cross_attention_dim=config.cross_attention_dim, attention_bias=config.attention_bias, sample_size=config.sample_size, num_vector_embeds=config.num_vector_embeds, activation_fn=config.activation_fn, num_embeds_ada_norm=config.num_embeds_ada_norm, ) dual_transformer.transformers[0] = image_transformer dual_transformer.transformers[1] = text_transformer self.image_unet.get_submodule(parent_name)[index] = dual_transformer self.image_unet.register_to_config(dual_cross_attention=True) def _revert_dual_attention(self): """ Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline` """ for name, module in self.image_unet.named_modules(): if isinstance(module, DualTransformer2DModel): parent_name, index = name.rsplit(".", 1) index = int(index) self.image_unet.get_submodule(parent_name)[index] = module.transformers[0] self.image_unet.register_to_config(dual_cross_attention=False) def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not """ def normalize_embeddings(encoder_output): embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) embeds_pooled = encoder_output.text_embeds embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) return embeds batch_size = len(prompt) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids if not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = normalize_embeddings(prompt_embeds) # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens = [""] * batch_size max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, 
padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not """ def normalize_embeddings(encoder_output): embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) embeds = self.image_encoder.visual_projection(embeds) embeds_pooled = embeds[:, 0:1] embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) return embeds batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) image_embeddings = self.image_encoder(pixel_values) image_embeddings = normalize_embeddings(image_embeddings) # duplicate image embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = image_embeddings.shape image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) negative_prompt_embeds = self.image_encoder(pixel_values) negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For 
classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and conditional embeddings into a single batch # to avoid doing two forward passes image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs(self, prompt, image, height, width, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}") if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list): raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}") if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: Tuple = ("text", "image")): for name, module in self.image_unet.named_modules(): if isinstance(module, DualTransformer2DModel): module.mix_ratio = mix_ratio for i, type in enumerate(condition_types): if type == "text": module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings module.transformer_index_for_condition[i] = 1 # use the second (text) transformer else: module.condition_lengths[i] = 257 module.transformer_index_for_condition[i] = 0 # use the first (image) transformer @torch.no_grad() def __call__( self, prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], image: Union[str, List[str]], text_to_image_strength: float = 0.5, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, **kwargs, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide image generation. height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. 
callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. Examples: ```py >>> from diffusers import VersatileDiffusionDualGuidedPipeline >>> import torch >>> import requests >>> from io import BytesIO >>> from PIL import Image >>> # let's download an initial image >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" >>> response = requests.get(url) >>> image = Image.open(BytesIO(response.content)).convert("RGB") >>> text = "a red car in the sun" >>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 ... ) >>> pipe.remove_unused_weights() >>> pipe = pipe.to("cuda") >>> generator = torch.Generator(device="cuda").manual_seed(0) >>> text_to_image_strength = 0.75 >>> image = pipe( ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator ... ).images[0] >>> image.save("./car_variation.png") ``` Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images. """ # 0. Default height and width to unet height = height or self.image_unet.config.sample_size * self.vae_scale_factor width = width or self.image_unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, image, height, width, callback_steps) # 2. Define call parameters prompt = [prompt] if not isinstance(prompt, list) else prompt image = [image] if not isinstance(image, list) else image batch_size = len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompts prompt_embeds = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance) dual_prompt_embeddings = torch.cat([prompt_embeds, image_embeddings], dim=1) prompt_types = ("text", "image") # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.image_unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, dual_prompt_embeddings.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Combine the attention blocks of the image and text UNets self.set_transformer_params(text_to_image_strength, prompt_types) # 8. 
Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Callable, List, Optional, Union import numpy as np import PIL.Image import torch import torch.utils.checkpoint from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ....image_processor import VaeImageProcessor from ....models import AutoencoderKL, UNet2DConditionModel from ....schedulers import KarrasDiffusionSchedulers from ....utils import deprecate, logging from ....utils.torch_utils import randn_tensor from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): r""" Pipeline for image variation using Versatile Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Parameters: vqvae ([`VQModel`]): Vector-quantized (VQ) model to encode and decode images to and from latent representations. bert ([`LDMBertModel`]): Text-encoder model based on [`~transformers.BERT`]. tokenizer ([`~transformers.BertTokenizer`]): A `BertTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
""" model_cpu_offload_seq = "bert->unet->vqvae" image_feature_extractor: CLIPImageProcessor image_encoder: CLIPVisionModelWithProjection image_unet: UNet2DConditionModel vae: AutoencoderKL scheduler: KarrasDiffusionSchedulers def __init__( self, image_feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection, image_unet: UNet2DConditionModel, vae: AutoencoderKL, scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( image_feature_extractor=image_feature_extractor, image_encoder=image_encoder, image_unet=image_unet, vae=vae, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
""" def normalize_embeddings(encoder_output): embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) embeds = self.image_encoder.visual_projection(embeds) embeds_pooled = embeds[:, 0:1] embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) return embeds if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4: prompt = list(prompt) batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) image_embeddings = self.image_encoder(pixel_values) image_embeddings = normalize_embeddings(image_embeddings) # duplicate image embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = image_embeddings.shape image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_images: List[str] if negative_prompt is None: uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, PIL.Image.Image): uncond_images = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_images = negative_prompt uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) negative_prompt_embeds = self.image_encoder(pixel_values) negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and conditional embeddings into a single batch # to avoid doing two forward passes image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) 
instead" deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs def check_inputs(self, image, height, width, callback_steps): if ( not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list) ): raise ValueError( "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" f" {type(image)}" ) if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise 
@torch.no_grad()
def __call__(
    self,
    image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
            The image prompt or prompts to guide the image generation. A single `PIL.Image.Image` is treated as
            a batch of one; for any other input the batch size is `len(image)`.
        height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Weight of the classifier-free guidance term. Guidance is enabled only when `guidance_scale > 1`;
            the final prediction is `uncond + guidance_scale * (cond - uncond)`.
        negative_prompt (`str` or `List[str]`, *optional*):
            Forwarded unchanged to `self._encode_prompt` together with the classifier-free-guidance flag. This
            method has no `negative_prompt_embeds` argument. NOTE(review): how `_encode_prompt` interprets the
            value is not visible from this function - confirm there.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per input image; the latent batch is
            `batch_size * num_images_per_prompt`.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. With `"latent"` the VAE decode is skipped and the final
            latents are used; in every case the result is passed to `self.image_processor.postprocess` with
            this `output_type`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return an `ImagePipelineOutput` instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that calls every `callback_steps` steps during inference. The function is called with the
            following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`, where `step` is the
            loop index divided by the scheduler's `order` attribute (1 when the scheduler has none).
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. It is validated by `self.check_inputs`
            before the loop starts.
        **kwargs:
            Accepted but not read anywhere in this method.

    Examples:

    ```py
    >>> from diffusers import VersatileDiffusionImageVariationPipeline
    >>> import torch
    >>> import requests
    >>> from io import BytesIO
    >>> from PIL import Image

    >>> # let's download an initial image
    >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"

    >>> response = requests.get(url)
    >>> image = Image.open(BytesIO(response.content)).convert("RGB")

    >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained(
    ...     "shi-labs/versatile-diffusion", torch_dtype=torch.float16
    ... )
    >>> pipe = pipe.to("cuda")

    >>> generator = torch.Generator(device="cuda").manual_seed(0)
    >>> image = pipe(image, generator=generator).images[0]
    >>> image.save("./car_variation.png")
    ```

    Returns:
        `ImagePipelineOutput` or `tuple`:
            If `return_dict` is `True`, an `ImagePipelineOutput` whose `images` field holds the post-processed
            result is returned, otherwise a one-element `tuple` containing that same result.
    """
    # 0. Default height and width to unet
    height = height or self.image_unet.config.sample_size * self.vae_scale_factor
    width = width or self.image_unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(image, height, width, callback_steps)

    # 2. Define call parameters
    batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image)
    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    image_embeddings = self._encode_prompt(
        image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.image_unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        image_embeddings.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs.
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    for i, t in enumerate(self.progress_bar(timesteps)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample

        # perform guidance
        if do_classifier_free_guidance:
            # the first half of the doubled batch is the unconditional prediction, the second the conditional one
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # call the callback, if provided
        if callback is not None and i % callback_steps == 0:
            step_idx = i // getattr(self.scheduler, "order", 1)
            callback(step_idx, t, latents)

    if not output_type == "latent":
        # undo the latent scaling before decoding with the VAE
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    else:
        image = latents

    image = self.image_processor.postprocess(image, output_type=output_type)

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Versatile Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize the text prompt.
        text_encoder ([`~transformers.CLIPTextModelWithProjection`]):
            Text encoder whose projected hidden states condition the denoising UNet.
        image_unet ([`UNet2DConditionModel`]):
            The UNet that denoises the latents in `__call__`.
        text_unet ([`UNetFlatConditionModel`], *optional*):
            Only used at construction time: when present, its `Transformer2DModel` blocks are swapped with those
            of `image_unet`. It can be dropped afterwards with `remove_unused_weights()`.
        vae ([`AutoencoderKL`]):
            Variational auto-encoder used to decode the final latents into images.
        scheduler ([`KarrasDiffusionSchedulers`]):
            A scheduler to be used in combination with `image_unet` to denoise the encoded image latents.
    """

    # Entries must name components registered in `__init__`; the order follows the order in which `__call__`
    # runs them (text encoder, then the denoising UNet, then the VAE decoder).
    model_cpu_offload_seq = "text_encoder->image_unet->vae"

    tokenizer: CLIPTokenizer
    image_feature_extractor: CLIPImageProcessor
    text_encoder: CLIPTextModelWithProjection
    image_unet: UNet2DConditionModel
    text_unet: UNetFlatConditionModel
    vae: AutoencoderKL
    scheduler: KarrasDiffusionSchedulers

    _optional_components = ["text_unet"]

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        image_unet: UNet2DConditionModel,
        text_unet: UNetFlatConditionModel,
        vae: AutoencoderKL,
        scheduler: KarrasDiffusionSchedulers,
    ):
        """Register the components and, when `text_unet` is given, swap its attention blocks into `image_unet`."""
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            image_unet=image_unet,
            text_unet=text_unet,
            vae=vae,
            scheduler=scheduler,
        )
        # Spatial down-sampling factor between pixel space and latent space.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        if self.text_unet is not None:
            self._swap_unet_attention_blocks()

    def _swap_unet_attention_blocks(self):
        """
        Swap the `Transformer2DModel` blocks between the image and text UNets
        """
        for name, module in self.image_unet.named_modules():
            if isinstance(module, Transformer2DModel):
                # `name` is "<parent path>.<index>"; the same path is looked up in both UNets.
                parent_name, index = name.rsplit(".", 1)
                index = int(index)
                self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = (
                    self.text_unet.get_submodule(parent_name)[index],
                    self.image_unet.get_submodule(parent_name)[index],
                )

    def remove_unused_weights(self):
        """Drop the `text_unet` component by re-registering it as `None`."""
        self.register_modules(text_unet=None)

    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Only read when
                `do_classifier_free_guidance` is true; `None` is replaced by empty strings.

        Returns:
            The normalized prompt embeddings, repeated `num_images_per_prompt` times. With classifier-free
            guidance the unconditional embeddings are concatenated in front of them along the batch dimension.

        Raises:
            TypeError: If `negative_prompt` is given and its type differs from the type of `prompt`.
            ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
        """

        def normalize_embeddings(encoder_output):
            # Project every token state, then divide by the norm of the pooled (projected) text embedding.
            embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
            embeds_pooled = encoder_output.text_embeds
            embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
            return embeds

        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

        if not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = normalize_embeddings(prompt_embeds)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    # Deprecated helper: decodes latents with the VAE and returns a float32 NumPy array in channels-last layout.
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Returns the subset of {"eta", "generator"} that `self.scheduler.step` accepts as keyword arguments.
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """Validate the call arguments and raise `ValueError` on the first inconsistency found.

        Checks, in order: `height`/`width` divisible by 8; `callback_steps` either `None` or a positive `int`;
        every name in `callback_on_step_end_tensor_inputs` is an allowed tensor input; exactly one of
        `prompt`/`prompt_embeds` is given and `prompt` is a `str` or `list`; `negative_prompt` and
        `negative_prompt_embeds` are not both given; directly passed embeddings share the same shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
        if callback_on_step_end_tensor_inputs is not None:
            # This class does not declare `_callback_tensor_inputs` itself.
            # NOTE(review): presumably no base class provides it either - fall back to "nothing allowed" so the
            # caller gets the intended ValueError rather than an AttributeError.
            allowed_tensor_inputs = getattr(self, "_callback_tensor_inputs", [])
            if not all(k in allowed_tensor_inputs for k in callback_on_step_end_tensor_inputs):
                raise ValueError(
                    f"`callback_on_step_end_tensor_inputs` has to be in {allowed_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in allowed_tensor_inputs]}"
                )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # Creates (or moves to `device`) the initial latents and scales them by `self.scheduler.init_noise_sigma`.
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide image generation.
            height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. When omitted, empty
                strings are used as the unconditional prompt. Ignored when guidance is disabled
                (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. With `"latent"` the VAE decode is skipped and the final
                latents are used; in every case the result is passed to `self.image_processor.postprocess` with
                this `output_type`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an `ImagePipelineOutput` instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. `None` is treated as 1, i.e. the
                callback is called at every step.
            **kwargs:
                Accepted but not read anywhere in this method.

        Examples:

        ```py
        >>> from diffusers import VersatileDiffusionTextToImagePipeline
        >>> import torch

        >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(
        ...     "shi-labs/versatile-diffusion", torch_dtype=torch.float16
        ... )
        >>> pipe.remove_unused_weights()
        >>> pipe = pipe.to("cuda")

        >>> generator = torch.Generator(device="cuda").manual_seed(0)
        >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0]
        >>> image.save("./astronaut.png")
        ```

        Returns:
            `ImagePipelineOutput` or `tuple`:
                If `return_dict` is `True`, an `ImagePipelineOutput` whose `images` field holds the post-processed
                result is returned, otherwise a one-element `tuple` containing that same result.
        """
        # 0. Default height and width to unet
        height = height or self.image_unet.config.sample_size * self.vae_scale_factor
        width = width or self.image_unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # `check_inputs` lets `callback_steps=None` through; normalise it here so the modulo in the
        # denoising loop cannot fail with a TypeError when a callback is supplied.
        if callback_steps is None:
            callback_steps = 1

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.image_unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

            # perform guidance
            if do_classifier_free_guidance:
                # `_encode_prompt` puts the unconditional embeddings first, so the first chunk is unconditional
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

        if not output_type == "latent":
            # undo the latent scaling before decoding with the VAE
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
VQDiffusionPipeline, } ) else: _import_structure["pipeline_vq_diffusion"] = ["LearnedClassifierFreeSamplingEmbeddings", "VQDiffusionPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ....utils.dummy_torch_and_transformers_objects import ( LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline, ) else: from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py ================================================ # Copyright 2024 Microsoft and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
    """
    Module that stores the learned text embeddings used for classifier free sampling.

    Args:
        learnable (`bool`):
            Whether an embedding table is allocated. When `True`, both `hidden_size` and `length` must be given.
        hidden_size (`int`, *optional*):
            Size of the second dimension of the embedding table.
        length (`int`, *optional*):
            Size of the first dimension of the embedding table.
    """

    @register_to_config
    def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None):
        super().__init__()

        self.learnable = learnable

        initial_value = None
        if self.learnable:
            assert hidden_size is not None, "learnable=True requires `hidden_size` to be set"
            assert length is not None, "learnable=True requires `length` to be set"

            # the table starts out as zeros of shape (length, hidden_size)
            initial_value = torch.zeros(length, hidden_size)

        # NOTE(review): when `learnable` is False, `None` is handed to `torch.nn.Parameter`; presumably this
        # yields an empty parameter - confirm against the torch documentation.
        self.embeddings = torch.nn.Parameter(initial_value)
scheduler ([`VQDiffusionScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. """ vqvae: VQModel text_encoder: CLIPTextModel tokenizer: CLIPTokenizer transformer: Transformer2DModel learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings scheduler: VQDiffusionScheduler def __init__( self, vqvae: VQModel, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, transformer: Transformer2DModel, scheduler: VQDiffusionScheduler, learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings, ): super().__init__() self.register_modules( vqvae=vqvae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, ) def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt", ) text_input_ids = text_inputs.input_ids if text_input_ids.shape[-1] > self.tokenizer.model_max_length: removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] prompt_embeds = self.text_encoder(text_input_ids.to(self.device))[0] # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. # While CLIP does normalize the pooled output of the text transformer when combining # the image and text embeddings, CLIP does not directly normalize the last hidden state. # # CLIP normalizing the pooled output. 
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True) # duplicate text embeddings for each generation per prompt prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: if self.learned_classifier_free_sampling_embeddings.learnable: negative_prompt_embeds = self.learned_classifier_free_sampling_embeddings.embeddings negative_prompt_embeds = negative_prompt_embeds.unsqueeze(0).repeat(batch_size, 1, 1) else: uncond_tokens = [""] * batch_size max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # See comment for normalizing text embeddings negative_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], num_inference_steps: int = 100, guidance_scale: float = 5.0, truncation_rate: float = 1.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, ) -> Union[ImagePipelineOutput, Tuple]: """ The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide image generation. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)): Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to zero. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. 
latents (`torch.Tensor` of shape (batch), *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Must be valid embedding indices.If not provided, a latents tensor will be generated of completely masked latent pixels. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images. """ if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) # get the initial completely masked latents unless the user supplied it latents_shape = (batch_size, self.transformer.num_latent_pixels) if latents is None: mask_class = self.transformer.num_vector_embeds - 1 latents = torch.full(latents_shape, mask_class).to(self.device) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any(): raise ValueError( "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0," f" {self.transformer.num_vector_embeds - 1} (inclusive)." ) latents = latents.to(self.device) # set timesteps self.scheduler.set_timesteps(num_inference_steps, device=self.device) timesteps_tensor = self.scheduler.timesteps.to(self.device) sample = latents for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the sample if we are doing classifier free guidance latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample # predict the un-noised image # model_output == `log_p_x_0` model_output = self.transformer(latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t).sample if do_classifier_free_guidance: model_output_uncond, model_output_text = model_output.chunk(2) model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond) model_output -= torch.logsumexp(model_output, dim=1, keepdim=True) model_output = self.truncate(model_output, truncation_rate) # remove `log(0)`'s (`-inf`s) model_output = model_output.clamp(-70) # compute the previous noisy sample x_t -> x_t-1 sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, sample) embedding_channels = self.vqvae.config.vq_embed_dim embeddings_shape = (batch_size, self.transformer.height, 
def truncate(self, log_p_x_0: torch.Tensor, truncation_rate: float) -> torch.Tensor:
    """
    Zero out the least likely classes of `log_p_x_0` along dim 1.

    Classes are ranked along dim 1 from most to least probable. A class is kept while the cumulative
    probability of the classes ranked before it is still below `truncation_rate`; all remaining classes get
    probability zero (`-inf` in log space). The most probable class is always kept.

    Args:
        log_p_x_0 (`torch.Tensor`):
            Log probabilities; the code indexes three dimensions and ranks along dim 1.
        truncation_rate (`float`):
            Cumulative probability threshold.

    Returns:
        `torch.Tensor`: A new tensor shaped like `log_p_x_0`; the input is not modified.
    """
    ranked_log_probs, order = log_p_x_0.sort(dim=1, descending=True)
    below_rate = ranked_log_probs.exp().cumsum(dim=1) < truncation_rate

    # Shift the mask down by one rank: the top-ranked entry is always kept, and every other entry is kept
    # only if the cumulative probability *before* it was still under the threshold.
    always_kept = torch.ones_like(below_rate[:, :1, :])
    keep_ranked = torch.cat((always_kept, below_rate[:, :-1, :]), dim=1)

    # undo the sort so the mask lines up with the original class order
    keep = keep_ranked.gather(1, order.argsort(1))

    # -inf = log(0)
    return log_p_x_0.masked_fill(~keep, -torch.inf)
2021 OpenAI # MIT License # # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, List, Optional, Tuple, Union import torch from ...models import AutoencoderKL, DiTTransformer2DModel from ...schedulers import KarrasDiffusionSchedulers from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class DiTPipeline(DiffusionPipeline): r""" Pipeline for image generation based on a Transformer backbone instead of a UNet. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Parameters: transformer ([`DiTTransformer2DModel`]): A class conditioned `DiTTransformer2DModel` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. scheduler ([`DDIMScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
    r"""
    Map label strings from ImageNet to corresponding class ids.

    Parameters:
        label (`str` or `list` of `str`):
            Label string or label strings to be mapped to class ids.

    Returns:
        `list` of `int`:
            Class ids to be processed by pipeline.

    Raises:
        `ValueError`: If a label is not a key of `self.labels`.
    """
    if isinstance(label, str):
        # A single label must be wrapped; `list("white shark")` would split it into characters.
        label = [label]
    elif not isinstance(label, list):
        # other iterables (e.g. tuples) are materialized as before
        label = list(label)

    for name in label:
        if name not in self.labels:
            raise ValueError(
                f"{name} does not exist. Please make sure to select one of the following labels: \n {self.labels}."
            )

    return [self.labels[name] for name in label]
num_inference_steps (`int`, *optional*, defaults to 250): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. Examples: ```py >>> from diffusers import DiTPipeline, DPMSolverMultistepScheduler >>> import torch >>> pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16) >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) >>> pipe = pipe.to("cuda") >>> # pick words from Imagenet class labels >>> pipe.labels # to print all available words >>> # pick words that exist in ImageNet >>> words = ["white shark", "umbrella"] >>> class_ids = pipe.get_label_ids(words) >>> generator = torch.manual_seed(33) >>> output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator) >>> image = output.images[0] # label 'white shark' ``` Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images """ batch_size = len(class_labels) latent_size = self.transformer.config.sample_size latent_channels = self.transformer.config.in_channels latents = randn_tensor( shape=(batch_size, latent_channels, latent_size, latent_size), generator=generator, device=self._execution_device, dtype=self.transformer.dtype, ) latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents class_labels = torch.tensor(class_labels, device=self._execution_device).reshape(-1) class_null = torch.tensor([1000] * batch_size, device=self._execution_device) class_labels_input = torch.cat([class_labels, class_null], 0) if 
guidance_scale > 1 else class_labels # set step values self.scheduler.set_timesteps(num_inference_steps) for t in self.progress_bar(self.scheduler.timesteps): if guidance_scale > 1: half = latent_model_input[: len(latent_model_input) // 2] latent_model_input = torch.cat([half, half], dim=0) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) timesteps = t if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" if isinstance(timesteps, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(latent_model_input.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(latent_model_input.shape[0]) # predict noise model_output noise_pred = self.transformer( latent_model_input, timestep=timesteps, class_labels=class_labels_input ).sample # perform guidance if guidance_scale > 1: eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:] cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) noise_pred = torch.cat([eps, rest], dim=1) # learned sigma if self.transformer.config.out_channels // 2 == latent_channels: model_output, _ = torch.split(noise_pred, latent_channels, dim=1) else: model_output = noise_pred # compute previous image: x_t -> x_t-1 latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample if guidance_scale > 1: latents, _ = latent_model_input.chunk(2, dim=0) else: latents = latent_model_input latents = 1 / 
self.vae.config.scaling_factor * latents samples = self.vae.decode(latents).sample samples = (samples / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": samples = self.numpy_to_pil(samples) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (samples,) return ImagePipelineOutput(images=samples) ================================================ FILE: src/diffusers/pipelines/flux/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_torch_available, is_transformers_available, ) _dummy_objects = {} _additional_imports = {} _import_structure = {"pipeline_output": ["FluxPipelineOutput"]} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_flux"] = ["FluxPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_flux import FluxPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) for name, value in _additional_imports.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: 
src/diffusers/pipelines/flux/pipeline_flux.py ================================================ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import VaeImageProcessor from ...loaders import FluxLoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import FluxPipelineOutput if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import FluxPipeline >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A cat holding a sign that says hello world" >>> # Depending on the variant being used, the pipeline call will slightly 
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """
    Linearly interpolate a shift value from an image sequence length.

    The line passes through `(base_seq_len, base_shift)` and `(max_seq_len, max_shift)`; values outside that
    range are extrapolated along the same line.

    Args:
        image_seq_len: Sequence length at which the line is evaluated.
        base_seq_len (`int`, defaults to 256): Sequence length that maps to `base_shift`.
        max_seq_len (`int`, defaults to 4096): Sequence length that maps to `max_shift`.
        base_shift (`float`, defaults to 0.5): Shift at `base_seq_len`.
        max_shift (`float`, defaults to 1.16): Shift at `max_seq_len`.

    Returns:
        The interpolated shift `mu`.
    """
    # slope and intercept of the line through the two anchor points
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin): r""" The Flux pipeline for text-to-image generation. Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ Args: transformer ([`FluxTransformer2DModel`]): Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
def __init__(
    self,
    scheduler: FlowMatchEulerDiscreteScheduler,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    text_encoder_2: T5EncoderModel,
    tokenizer_2: T5TokenizerFast,
    transformer: FluxTransformer2DModel,
):
    """
    Register the pipeline components and derive the helper attributes computed from them.

    Besides registering the modules, this sets `vae_scale_factor`, `image_processor`, `tokenizer_max_length`
    and `default_sample_size` on the pipeline.
    """
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        transformer=transformer,
        scheduler=scheduler,
    )

    # scale factor is 2 ** (number of VAE block output channels); falls back to 16 without a VAE
    if hasattr(self, "vae") and self.vae is not None:
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels))
    else:
        self.vae_scale_factor = 16
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # maximum token count of the first tokenizer; falls back to 77 without a tokenizer
    if hasattr(self, "tokenizer") and self.tokenizer is not None:
        self.tokenizer_max_length = self.tokenizer.model_max_length
    else:
        self.tokenizer_max_length = 77

    self.default_sample_size = 64
truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds def _get_clip_prompt_embeds( self, prompt: Union[str, List[str]], num_images_per_prompt: int = 1, device: Optional[torch.device] = None, ): device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer_max_length, truncation=True, return_overflowing_tokens=False, return_length=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up 
to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds def encode_prompt( self, prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], device: Optional[torch.device] = None, num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, max_sequence_length: int = 512, lora_scale: Optional[float] = None, ): r""" Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in all text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # We only use the pooled prompt output from the CLIPTextModel pooled_prompt_embeds = self._get_clip_prompt_embeds( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, ) prompt_embeds = self._get_t5_prompt_embeds( prompt=prompt_2, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) if self.text_encoder is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) return prompt_embeds, pooled_prompt_embeds, text_ids def check_inputs( self, prompt, prompt_2, height, width, prompt_embeds=None, pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, 
max_sequence_length=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
    """Build per-patch position ids of shape `(batch_size, (height // 2) * (width // 2), 3)`.

    Channel 0 stays zero, channel 1 holds the patch row index and channel 2 the patch column index.
    """
    rows, cols = height // 2, width // 2

    ids = torch.zeros(rows, cols, 3)
    ids[..., 1] = ids[..., 1] + torch.arange(rows)[:, None]
    ids[..., 2] = ids[..., 2] + torch.arange(cols)[None, :]

    grid_h, grid_w, grid_c = ids.shape
    batched = ids[None, :].repeat(batch_size, 1, 1, 1)
    flat = batched.reshape(batch_size, grid_h * grid_w, grid_c)
    return flat.to(device=device, dtype=dtype)


@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
    """Fold every 2x2 spatial patch into the channel axis.

    `(batch, channels, height, width)` becomes `(batch, (height // 2) * (width // 2), channels * 4)`.
    """
    half_h, half_w = height // 2, width // 2
    patches = latents.view(batch_size, num_channels_latents, half_h, 2, half_w, 2)
    # Move the patch grid in front of (channel, row-in-patch, col-in-patch).
    patches = patches.permute(0, 2, 4, 1, 3, 5)
    return patches.reshape(batch_size, half_h * half_w, num_channels_latents * 4)


@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
    """Inverse of `_pack_latents`: rebuild `(batch, channels // 4, H, W)` from packed patches.

    `height` and `width` are divided by `vae_scale_factor` to obtain the patch-grid size; the output
    spatial size is twice that grid.
    """
    batch_size, _num_patches, channels = latents.shape

    grid_h = height // vae_scale_factor
    grid_w = width // vae_scale_factor

    unfolded = latents.view(batch_size, grid_h, grid_w, channels // 4, 2, 2)
    unfolded = unfolded.permute(0, 3, 1, 4, 2, 5)
    return unfolded.reshape(batch_size, channels // (2 * 2), grid_h * 2, grid_w * 2)


def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Return `(latents, latent_image_ids)` for a generation run.

    When `latents` is given it is only moved to `device` / `dtype`; otherwise noise is sampled with
    `randn_tensor` and packed with `_pack_latents`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size` (only checked
            when noise has to be sampled).
    """
    latent_h = 2 * (int(height) // self.vae_scale_factor)
    latent_w = 2 * (int(width) // self.vae_scale_factor)

    if latents is None:
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        noise_shape = (batch_size, num_channels_latents, latent_h, latent_w)
        noise = randn_tensor(noise_shape, generator=generator, device=device, dtype=dtype)
        packed = self._pack_latents(noise, batch_size, num_channels_latents, latent_h, latent_w)
        image_ids = self._prepare_latent_image_ids(batch_size, latent_h, latent_w, device, dtype)
        return packed, image_ids

    # Caller-supplied latents are used as they are (no packing), only moved to the target device/dtype.
    image_ids = self._prepare_latent_image_ids(batch_size, latent_h, latent_w, device, dtype)
    return latents.to(device=device, dtype=dtype), image_ids


@property
def guidance_scale(self):
    """Value of `self._guidance_scale`."""
    return self._guidance_scale


@property
def joint_attention_kwargs(self):
    """Value of `self._joint_attention_kwargs`."""
    return self._joint_attention_kwargs


@property
def num_timesteps(self):
    """Value of `self._num_timesteps`."""
    return self._num_timesteps


@property
def interrupt(self):
    """Value of `self._interrupt`."""
    return self._interrupt
Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is will be used instead height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. 
latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. 
The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. Examples: Returns: [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, prompt_2, height, width, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale self._joint_attention_kwargs = joint_attention_kwargs self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) ( prompt_embeds, pooled_prompt_embeds, text_ids, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, lora_scale=lora_scale, ) # 4. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, latent_image_ids = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, self.scheduler.config.max_image_seq_len, self.scheduler.config.base_shift, self.scheduler.config.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) if invert_image: timesteps = reversed(timesteps) result_latent_list = [] if percentage_of_steps != 1: end_timestep_idx = int(len(timesteps) * percentage_of_steps) if invert_image: timesteps = timesteps[:end_timestep_idx] else: timesteps = timesteps[len(timesteps) - end_timestep_idx:] # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) # handle guidance if self.transformer.config.guidance_embeds: if isinstance(guidance_scale, list): guidance = torch.tensor(guidance_scale, device=device) else: guidance = torch.tensor([guidance_scale], device=device) guidance = guidance.expand(latents.shape[0]) else: guidance = None noise_pred = self.transformer( hidden_states=latents, # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, step_index=i, mm_copy_blocks=mm_copy_blocks, single_copy_blocks=single_copy_blocks, mm_skip_blocks=mm_skip_blocks, single_skip_blocks=single_skip_blocks, )[0] latents_dtype = latents.dtype if invert_image: # compute the next noisy sample x_t-1 -> x_t latents = self.scheduler.step_forward(noise_pred, t, latents, return_dict=False)[0] result_latent_list.append(latents) else: # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if inverted_latent_list is not None: latents[0] = inverted_latent_list[-i][0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
@dataclass
class FluxPipelineOutput(BaseOutput):
    """
    Output class for Flux pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
def disable_free_init(self):
    """Disables the FreeInit mechanism if enabled."""
    self._free_init_num_iters = None


@property
def free_init_enabled(self):
    """Whether `_free_init_num_iters` has been set to a non-`None` value."""
    return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None


def _get_free_init_freq_filter(
    self,
    shape: Tuple[int, ...],
    device: Union[str, torch.device],
    filter_type: str,
    order: float,
    spatial_stop_frequency: float,
    temporal_stop_frequency: float,
) -> torch.Tensor:
    r"""Returns the FreeInit filter based on filter type and other input conditions.

    Args:
        shape: Shape of the filter; the last three entries are read as (time, height, width).
        device: Device the returned mask is moved to.
        filter_type: One of `"butterworth"`, `"gaussian"` or `"ideal"`.
        order: Exponent used by the `butterworth` filter.
        spatial_stop_frequency: Stop frequency for the height / width axes.
        temporal_stop_frequency: Stop frequency for the time axis.

    Returns:
        A tensor of shape `shape` on `device`. It is all zeros when either stop frequency is 0.

    Raises:
        NotImplementedError: If `filter_type` is not one of the supported names (only checked when
            both stop frequencies are non-zero).
    """
    time, height, width = shape[-3], shape[-2], shape[-1]
    mask = torch.zeros(shape)

    if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
        # Move the all-zero mask to `device` as well, so both return paths agree.
        return mask.to(device)

    if filter_type == "butterworth":

        def retrieve_mask(x):
            return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
    elif filter_type == "gaussian":

        def retrieve_mask(x):
            return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
    elif filter_type == "ideal":

        def retrieve_mask(x):
            return 1 if x <= spatial_stop_frequency * 2 else 0
    else:
        raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")

    # Each axis contributes an independent squared term, so compute them once per axis instead of
    # once per element. The summation order (time + height) + width is kept, so values are unchanged.
    temporal_scale = spatial_stop_frequency / temporal_stop_frequency
    time_terms = [(temporal_scale * (2 * t / time - 1)) ** 2 for t in range(time)]
    height_terms = [(2 * h / height - 1) ** 2 for h in range(height)]
    width_terms = [(2 * w / width - 1) ** 2 for w in range(width)]

    for t, time_term in enumerate(time_terms):
        for h, height_term in enumerate(height_terms):
            partial = time_term + height_term
            for w, width_term in enumerate(width_terms):
                mask[..., t, h, w] = retrieve_mask(partial + width_term)

    return mask.to(device)


def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor:
    r"""Noise reinitialization.

    Keeps the part of `x` selected by `low_pass_filter` in the (shifted) frequency domain of the last
    three axes and takes the complementary part (`1 - low_pass_filter`) from `noise`.
    """
    # FFT
    x_freq = fft.fftn(x, dim=(-3, -2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
    noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
    noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))

    # frequency mix
    high_pass_filter = 1 - low_pass_filter
    x_freq_low = x_freq * low_pass_filter
    noise_freq_high = noise_freq * high_pass_filter
    x_freq_mixed = x_freq_low + noise_freq_high  # mix in freq domain

    # IFFT
    x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
    x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real

    return x_mixed


def _apply_free_init(
    self,
    latents: torch.Tensor,
    free_init_iteration: int,
    num_inference_steps: int,
    device: torch.device,
    dtype: torch.dtype,
    generator: torch.Generator,
):
    """Prepare the latents and scheduler timesteps for one FreeInit iteration.

    On iteration 0 the incoming latents are stored as the initial noise and returned unchanged. On
    later iterations the latents are re-noised with that stored noise at the last training timestep
    and their high frequencies are replaced by freshly sampled noise.

    Returns:
        A tuple `(latents, timesteps)` where `timesteps` is `self.scheduler.timesteps` after the
        scheduler has been (re)configured.
    """
    if free_init_iteration == 0:
        self._free_init_initial_noise = latents.detach().clone()
    else:
        latent_shape = latents.shape

        # A single filter (batch dimension 1) is broadcast over the batch.
        free_init_filter_shape = (1, *latent_shape[1:])
        free_init_freq_filter = self._get_free_init_freq_filter(
            shape=free_init_filter_shape,
            device=device,
            filter_type=self._free_init_method,
            order=self._free_init_order,
            spatial_stop_frequency=self._free_init_spatial_stop_frequency,
            temporal_stop_frequency=self._free_init_temporal_stop_frequency,
        )

        current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
        diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()

        z_t = self.scheduler.add_noise(
            original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
        ).to(dtype=torch.float32)

        z_rand = randn_tensor(
            shape=latent_shape,
            generator=generator,
            device=device,
            dtype=torch.float32,
        )
        latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
        latents = latents.to(dtype)

    # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
    if self._free_init_use_fast_sampling:
        num_inference_steps = max(
            1, int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
        )

    if num_inference_steps > 0:
        self.scheduler.set_timesteps(num_inference_steps, device=device)

    return latents, self.scheduler.timesteps
class AnimateDiffFreeNoiseMixin:
    r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169).

    The methods below read `self.unet`, `self.motion_adapter`, `self.vae_scale_factor`, `self.device` and
    `self.dtype`, so the class this mixin is combined with has to provide them.
    """

    @staticmethod
    def _shared_transformer_block_kwargs(transformer_block) -> dict:
        r"""Collect the constructor arguments shared by `BasicTransformerBlock` and `FreeNoiseTransformerBlock`."""
        return {
            "dim": transformer_block.dim,
            "num_attention_heads": transformer_block.num_attention_heads,
            "attention_head_dim": transformer_block.attention_head_dim,
            "dropout": transformer_block.dropout,
            "cross_attention_dim": transformer_block.cross_attention_dim,
            "activation_fn": transformer_block.activation_fn,
            "attention_bias": transformer_block.attention_bias,
            "only_cross_attention": transformer_block.only_cross_attention,
            "double_self_attention": transformer_block.double_self_attention,
            "positional_embeddings": transformer_block.positional_embeddings,
            "num_positional_embeddings": transformer_block.num_positional_embeddings,
        }

    @staticmethod
    def _tile_frames(latents: torch.Tensor, context_length: int, num_frames: int) -> torch.Tensor:
        r"""Repeat `latents` along the frame axis (dim 2) and crop the result to `num_frames` frames."""
        num_repeats = (num_frames + context_length - 1) // context_length
        return torch.cat([latents] * num_repeats, dim=2)[:, :, :num_frames]

    def _enable_free_noise_in_block(
        self, block: "Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]"
    ):
        r"""Helper function to enable FreeNoise in transformer blocks.

        Blocks that already are `FreeNoiseTransformerBlock`s only get their FreeNoise properties refreshed. Every
        other block must be a `BasicTransformerBlock`; it is replaced by a `FreeNoiseTransformerBlock` built from
        the same configuration and loaded with the same weights.
        """
        for motion_module in block.motion_modules:
            transformer_blocks = motion_module.transformer_blocks

            for i in range(len(transformer_blocks)):
                current_block = transformer_blocks[i]

                if isinstance(current_block, FreeNoiseTransformerBlock):
                    current_block.set_free_noise_properties(
                        self._free_noise_context_length,
                        self._free_noise_context_stride,
                        self._free_noise_weighting_scheme,
                    )
                    continue

                assert isinstance(current_block, BasicTransformerBlock)

                transformer_blocks[i] = FreeNoiseTransformerBlock(
                    **self._shared_transformer_block_kwargs(current_block),
                    context_length=self._free_noise_context_length,
                    context_stride=self._free_noise_context_stride,
                    weighting_scheme=self._free_noise_weighting_scheme,
                ).to(device=self.device, dtype=self.dtype)

                transformer_blocks[i].load_state_dict(current_block.state_dict(), strict=True)

    def _disable_free_noise_in_block(
        self, block: "Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]"
    ):
        r"""Helper function to disable FreeNoise in transformer blocks.

        Every `FreeNoiseTransformerBlock` is replaced by a `BasicTransformerBlock` built from the same
        configuration and loaded with the same weights. Other blocks are left untouched.
        """
        for motion_module in block.motion_modules:
            transformer_blocks = motion_module.transformer_blocks

            for i in range(len(transformer_blocks)):
                current_block = transformer_blocks[i]

                if not isinstance(current_block, FreeNoiseTransformerBlock):
                    continue

                transformer_blocks[i] = BasicTransformerBlock(
                    **self._shared_transformer_block_kwargs(current_block),
                ).to(device=self.device, dtype=self.dtype)

                transformer_blocks[i].load_state_dict(current_block.state_dict(), strict=True)

    def _prepare_latents_free_noise(
        self,
        batch_size: int,
        num_channels_latents: int,
        num_frames: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
    ):
        r"""Create (or extend) the initial latents according to the configured FreeNoise noise type.

        Args:
            batch_size (`int`): Size of the first latent dimension.
            num_channels_latents (`int`): Size of the second latent dimension.
            num_frames (`int`): Number of frames (third latent dimension) the result should have.
            height (`int`), width (`int`): Pixel sizes; divided by `self.vae_scale_factor` for the latent shape.
            dtype (`torch.dtype`), device (`torch.device`): Used when sampling new latents.
            generator (`torch.Generator`, *optional*): Generator for sampling and for the shuffling permutation.
            latents (`torch.Tensor`, *optional*):
                Pre-generated latents. Must have either `num_frames` frames (returned unchanged) or
                `context_length` frames (extended to `num_frames` frames for the "shuffle_context" and
                "repeat_context" noise types).

        Returns:
            `torch.Tensor`: The prepared latents.

        Raises:
            ValueError: If a generator list does not match `batch_size`, or if `latents` has an unexpected
                number of frames.
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        context_length = self._free_noise_context_length
        context_stride = self._free_noise_context_stride
        noise_type = self._free_noise_noise_type

        # Only "repeat_context" samples a single context window that is tiled afterwards. The decision has to be
        # made on the noise type; comparing the (integer) context length against the string never matched.
        context_num_frames = context_length if noise_type == "repeat_context" else num_frames

        shape = (
            batch_size,
            num_channels_latents,
            context_num_frames,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            if noise_type == "random":
                return latents
        else:
            if latents.size(2) == num_frames:
                return latents
            elif latents.size(2) != context_length:
                raise ValueError(
                    f"You have passed `latents` as a parameter to FreeNoise. The expected number of frames is either {num_frames} or {context_length}, but found {latents.size(2)}"
                )
            latents = latents.to(device)

        if noise_type == "shuffle_context":
            if latents.size(2) < num_frames:
                # User-supplied latents only cover one context window. Grow them to the full length first so
                # that the in-place writes below stay in bounds (this also leaves the caller's tensor untouched).
                latents = self._tile_frames(latents, context_length, num_frames)

            for current_start in range(context_length, num_frames, context_stride):
                # ensure window is within bounds
                window_start = max(0, current_start - context_length)
                window_end = min(num_frames, window_start + context_stride)
                window_length = window_end - window_start

                if window_length == 0:
                    break

                source_indices = torch.arange(window_start, window_end)
                shuffled_indices = source_indices[torch.randperm(window_length, generator=generator)]

                # the last chunk may be shorter than the window; truncating is a no-op otherwise
                current_end = min(num_frames, current_start + window_length)
                shuffled_indices = shuffled_indices[: current_end - current_start]
                latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices]

        elif noise_type == "repeat_context":
            latents = self._tile_frames(latents, context_length, num_frames)

        return latents

    def enable_free_noise(
        self,
        context_length: Optional[int] = 16,
        context_stride: int = 4,
        weighting_scheme: str = "pyramid",
        noise_type: str = "shuffle_context",
    ) -> None:
        r"""
        Enable long video generation using FreeNoise.

        Args:
            context_length (`int`, defaults to `16`, *optional*):
                The number of video frames to process at once. It's recommended to set this to the maximum frames the
                Motion Adapter was trained with (usually 16/24/32). If `None`, the default value from the motion
                adapter config is used.
            context_stride (`int`, *optional*):
                Long videos are generated by processing many frames. FreeNoise processes these frames in sliding
                windows of size `context_length`. Context stride allows you to specify how many frames to skip between
                each window. For example, a context length of 16 and context stride of 4 would process 24 frames as:
                    [0, 15], [4, 19], [8, 23] (0-based indexing)
            weighting_scheme (`str`, defaults to `pyramid`):
                Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting
                schemes are supported currently:
                    - "pyramid"
                        Performs weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
            noise_type (`str`, defaults to "shuffle_context"):
                How the initial latents are built for frames beyond the first `context_length` frames:
                    - "shuffle_context"
                        Frames are filled, `context_stride` at a time, with randomly permuted copies of the noise of
                        earlier frames.
                    - "repeat_context"
                        Noise is sampled for `context_length` frames and repeated along the frame axis.
                    - "random"
                        Independent noise is sampled for every frame.

        Raises:
            ValueError: If `weighting_scheme` or `noise_type` is not one of the supported values.
        """

        allowed_weighting_scheme = ["pyramid"]
        allowed_noise_type = ["shuffle_context", "repeat_context", "random"]

        # Resolve the documented `None` default before comparing, otherwise `None > int` raises a TypeError.
        context_length = context_length or self.motion_adapter.config.motion_max_seq_length

        if context_length > self.motion_adapter.config.motion_max_seq_length:
            logger.warning(
                f"You have set {context_length=} which is greater than {self.motion_adapter.config.motion_max_seq_length=}. This can lead to bad generation results."
            )
        if weighting_scheme not in allowed_weighting_scheme:
            raise ValueError(
                f"The parameter `weighting_scheme` must be one of {allowed_weighting_scheme}, but got {weighting_scheme=}"
            )
        if noise_type not in allowed_noise_type:
            raise ValueError(f"The parameter `noise_type` must be one of {allowed_noise_type}, but got {noise_type=}")

        self._free_noise_context_length = context_length
        self._free_noise_context_stride = context_stride
        self._free_noise_weighting_scheme = weighting_scheme
        self._free_noise_noise_type = noise_type

        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
        for block in blocks:
            self._enable_free_noise_in_block(block)

    def disable_free_noise(self) -> None:
        r"""Disable FreeNoise and restore `BasicTransformerBlock`s in all motion modules of the UNet."""
        self._free_noise_context_length = None

        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
        for block in blocks:
            self._disable_free_noise_in_block(block)

    @property
    def free_noise_enabled(self):
        r"""Whether FreeNoise is active, i.e. a context length has been set by `enable_free_noise`."""
        return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_hunyuandit import HunyuanDiTPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py ================================================ # Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, HunyuanDiT2DModel from ...models.embeddings import get_2d_rotary_pos_embed from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import DDPMScheduler from ...utils import ( is_torch_xla_available, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import HunyuanDiTPipeline >>> pipe = HunyuanDiTPipeline.from_pretrained( ... "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16 ... 
# Aspect ratios (width / height) of the resolution buckets below, in the same order.
STANDARD_RATIO = np.array(
    [
        1.0,  # 1:1
        4.0 / 3.0,  # 4:3
        3.0 / 4.0,  # 3:4
        16.0 / 9.0,  # 16:9
        9.0 / 16.0,  # 9:16
    ]
)
# (width, height) candidates per aspect ratio.
STANDARD_SHAPE = [
    [(1024, 1024), (1280, 1280)],  # 1:1
    [(1024, 768), (1152, 864), (1280, 960)],  # 4:3
    [(768, 1024), (864, 1152), (960, 1280)],  # 3:4
    [(1280, 768)],  # 16:9
    [(768, 1280)],  # 9:16
]
# Pixel area of every candidate, grouped like STANDARD_SHAPE.
STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]
# Flat list of every candidate shape, in bucket order.
SUPPORTED_SHAPE = [shape for shapes in STANDARD_SHAPE for shape in shapes]


def map_to_standard_shapes(target_width, target_height):
    """Return the `(width, height)` from `STANDARD_SHAPE` that best matches the requested size.

    The aspect-ratio bucket closest to `target_width / target_height` is selected first; inside that
    bucket the shape whose pixel area is closest to `target_width * target_height` is returned.
    """
    ratio_distance = np.abs(STANDARD_RATIO - target_width / target_height)
    ratio_idx = ratio_distance.argmin()

    area_distance = np.abs(STANDARD_AREA[ratio_idx] - target_width * target_height)
    area_idx = area_distance.argmin()

    width, height = STANDARD_SHAPE[ratio_idx][area_idx]
    return width, height


def get_resize_crop_region_for_grid(src, tgt_size):
    """Fit a `(height, width)` source into a square grid of side `tgt_size`, centered.

    The longer source side is scaled to `tgt_size` and the other side is scaled by the same factor
    (rounded to an integer). Returns `((top, left), (bottom, right))` of the centered region.
    """
    src_h, src_w = src

    # resize
    if src_h / src_w > 1:
        new_h = tgt_size
        new_w = int(round(tgt_size / src_h * src_w))
    else:
        new_w = tgt_size
        new_h = int(round(tgt_size / src_w * src_h))

    # center the resized region inside the square grid
    top = int(round((tgt_size - new_h) / 2.0))
    left = int(round((tgt_size - new_w) / 2.0))

    return (top, left), (top + new_h, left + new_w)
See Section 3.4 """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg class HunyuanDiTPipeline(DiffusionPipeline): r""" Pipeline for English/Chinese-to-image generation using HunyuanDiT. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by ourselves) Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use `sdxl-vae-fp16-fix`. text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). HunyuanDiT uses a fine-tuned [bilingual CLIP]. tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): A `BertTokenizer` or `CLIPTokenizer` to tokenize text. transformer ([`HunyuanDiT2DModel`]): The HunyuanDiT model designed by Tencent Hunyuan. text_encoder_2 (`T5EncoderModel`): The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. tokenizer_2 (`MT5Tokenizer`): The tokenizer for the mT5 embedder. scheduler ([`DDPMScheduler`]): A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: BertModel,
    tokenizer: BertTokenizer,
    transformer: HunyuanDiT2DModel,
    scheduler: DDPMScheduler,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
    # NOTE(review): these two defaults are the classes themselves, not instances; it looks like type
    # annotations were intended. Left unchanged to keep the signature stable - confirm with callers.
    text_encoder_2=T5EncoderModel,
    tokenizer_2=MT5Tokenizer,
):
    """Register the pipeline components and derive the VAE scale factor and default sample size.

    Raises:
        ValueError: If a safety checker is supplied without a feature extractor.
    """
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        transformer=transformer,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        text_encoder_2=text_encoder_2,
    )

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # the message interpolates the class name, so it has to be an f-string
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # fall back to a factor of 8 when no VAE is registered
    self.vae_scale_factor = (
        2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
    )
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
    # fall back to a sample size of 128 when no transformer is registered
    self.default_sample_size = (
        self.transformer.config.sample_size
        if hasattr(self, "transformer") and self.transformer is not None
        else 128
    )


def encode_prompt(
    self,
    prompt: str,
    device: torch.device = None,
    dtype: torch.dtype = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[str] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    max_sequence_length: Optional[int] = None,
    text_encoder_index: int = 0,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device (`torch.device`):
            torch device. Defaults to the pipeline's execution device.
        dtype (`torch.dtype`):
            torch dtype. Defaults to the dtype of `text_encoder_2` if it is set, otherwise to the dtype of the
            transformer.
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
        max_sequence_length (`int`, *optional*):
            maximum sequence length to use for the prompt. Defaults to 77 for text encoder `0` and 256 for text
            encoder `1`.
        text_encoder_index (`int`, *optional*):
            Index of the text encoder to use. `0` for clip and `1` for T5.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask)`.
    """
    if dtype is None:
        if self.text_encoder_2 is not None:
            dtype = self.text_encoder_2.dtype
        elif self.transformer is not None:
            dtype = self.transformer.dtype

    if device is None:
        device = self._execution_device

    tokenizer = (self.tokenizer, self.tokenizer_2)[text_encoder_index]
    text_encoder = (self.text_encoder, self.text_encoder_2)[text_encoder_index]

    if max_sequence_length is not None:
        max_length = max_sequence_length
    else:
        if text_encoder_index == 0:
            max_length = 77
        if text_encoder_index == 1:
            max_length = 256

    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        # tokenize again without truncation only to report what was cut off
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        was_truncated = untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        )
        if was_truncated:
            removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {tokenizer.model_max_length} tokens: {removed_text}"
            )

        prompt_attention_mask = text_inputs.attention_mask.to(device)
        prompt_embeds = text_encoder(
            text_input_ids.to(device),
            attention_mask=prompt_attention_mask,
        )[0]
        # NOTE(review): the mask is only repeated when the embeddings are computed here; a mask passed in
        # together with `prompt_embeds` is returned as given - confirm callers expect that.
        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # pad the negative prompt to the sequence length of the positive embeddings
        uncond_input = tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=prompt_embeds.shape[1],
            truncation=True,
            return_tensors="pt",
        )

        negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
        negative_prompt_embeds = text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=negative_prompt_attention_mask,
        )[0]
        negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    if self.safety_checker is None:
        has_nsfw_concept = None
    else:
        if torch.is_tensor(image):
            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
        else:
            feature_extractor_input = self.image_processor.numpy_to_pil(image)
        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
    return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]

    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    # check if the scheduler accepts generator
    accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
    if accepts_generator:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs


def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    prompt_attention_mask=None,
    negative_prompt_attention_mask=None,
    prompt_embeds_2=None,
    negative_prompt_embeds_2=None,
    prompt_attention_mask_2=None,
    negative_prompt_attention_mask_2=None,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the arguments of a pipeline call; the first violated rule raises a `ValueError`.

    Rules, in the order they are checked:
      * `height` and `width` are multiples of 8;
      * every requested callback tensor input is listed in `self._callback_tensor_inputs`;
      * exactly one of `prompt` / `prompt_embeds` is given, `prompt_embeds_2` accompanies a missing
        `prompt`, and `prompt` is a `str` or a `list`;
      * every set of pre-computed embeddings comes with its attention mask;
      * `negative_prompt` and `negative_prompt_embeds` are not both given;
      * positive and negative embeddings passed together have the same shape.
    """
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None:
        unknown_inputs = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown_inputs:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown_inputs}"
            )

    # prompt versus pre-computed prompt embeddings
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is None and prompt_embeds_2 is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # pre-computed embeddings need their attention masks
    if prompt_embeds is not None and prompt_attention_mask is None:
        raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

    if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
        raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
        raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

    if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
        raise ValueError(
            "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
        )

    # positive and negative embeddings must line up
    both_given = prompt_embeds is not None and negative_prompt_embeds is not None
    if both_given and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    both_given_2 = prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None
    if both_given_2 and prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
        raise ValueError(
            "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
            f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
            f" {negative_prompt_embeds_2.shape}."
        )


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    shape = (
        batch_size,
        num_channels_latents,
        int(height) // self.vae_scale_factor,
        int(width) // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents


@property
def guidance_scale(self):
    return self._guidance_scale


@property
def guidance_rescale(self):
    return self._guidance_rescale
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_2: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds_2: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    prompt_attention_mask_2: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    guidance_rescale: float = 0.0,
    original_size: Optional[Tuple[int, int]] = (1024, 1024),
    target_size: Optional[Tuple[int, int]] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    use_resolution_binning: bool = True,
):
    r"""
    The call function to the pipeline for generation with HunyuanDiT.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image. Rounded down to a multiple of 16.
        width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image. Rounded down to a multiple of 16.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated latents forwarded to `prepare_latents`. If not provided, `prepare_latents` is called
            with `generator` to create them.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        prompt_embeds_2 (`torch.Tensor`, *optional*):
            Pre-generated text embeddings for the second text encoder (`text_encoder_index=1`). If not provided,
            text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings for the second text encoder (`text_encoder_index=1`). If not
            provided, `negative_prompt_embeds_2` are generated from the `negative_prompt` input argument.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
        prompt_attention_mask_2 (`torch.Tensor`, *optional*):
            Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
        negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
            Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. With `"latent"`
            the VAE decode and safety checker are skipped and the latents are post-processed directly.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A callback function or a list of callback functions to be called at the end of each denoising step.
        callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to `["latents"]`):
            Names of local tensors that should be passed to the callback function. Overridden by the callback's
            own `tensor_inputs` when a `PipelineCallback` or `MultiPipelineCallbacks` is given.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
            Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
        original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
            The original size of the image. Used to calculate the time ids.
        target_size (`Tuple[int, int]`, *optional*):
            The target size of the image. Used to calculate the time ids. Defaults to `(height, width)`.
        crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
            The top left coordinates of the crop. Used to calculate the time ids.
        use_resolution_binning (`bool`, *optional*, defaults to `True`):
            Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the
            closest standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864,
            1280x960, 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to
            `True`.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content (`None` when `output_type="latent"`).
    """
    # Callback objects carry their own list of tensor names; it replaces the argument.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 0. default height and width
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor
    # round both sides down to a multiple of 16
    height = int((height // 16) * 16)
    width = int((width // 16) * 16)

    if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE:
        # note the (width, height) argument/return order of the helper
        width, height = map_to_standard_shapes(width, height)
        height = int(height)
        width = int(width)
        logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}")

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
        prompt_embeds_2,
        negative_prompt_embeds_2,
        prompt_attention_mask_2,
        negative_prompt_attention_mask_2,
        callback_on_step_end_tensor_inputs,
    )
    # state read back through the `guidance_scale`, `guidance_rescale`, `do_classifier_free_guidance`
    # and `interrupt` properties
    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    # first text encoder (index 0), sequences limited to 77 tokens
    (
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt=prompt,
        device=device,
        dtype=self.transformer.dtype,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        max_sequence_length=77,
        text_encoder_index=0,
    )
    # second text encoder (index 1), sequences limited to 256 tokens
    (
        prompt_embeds_2,
        negative_prompt_embeds_2,
        prompt_attention_mask_2,
        negative_prompt_attention_mask_2,
    ) = self.encode_prompt(
        prompt=prompt,
        device=device,
        dtype=self.transformer.dtype,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds_2,
        negative_prompt_embeds=negative_prompt_embeds_2,
        prompt_attention_mask=prompt_attention_mask_2,
        negative_prompt_attention_mask=negative_prompt_attention_mask_2,
        max_sequence_length=256,
        text_encoder_index=1,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7 create image_rotary_emb, style embedding & time ids
    # NOTE(review): the literal 8 presumably mirrors the VAE downscale factor - confirm against `vae_scale_factor`
    grid_height = height // 8 // self.transformer.config.patch_size
    grid_width = width // 8 // self.transformer.config.patch_size
    base_size = 512 // 8 // self.transformer.config.patch_size
    grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
    image_rotary_emb = get_2d_rotary_pos_embed(
        self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
    )

    style = torch.tensor([0], device=device)

    target_size = target_size or (height, width)
    # tuple concatenation: (orig_h, orig_w, target_h, target_w, crop_top, crop_left)
    add_time_ids = list(original_size + target_size + crops_coords_top_left)
    add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)

    if self.do_classifier_free_guidance:
        # negative (unconditional) entries go first; the guidance step below relies on this order
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
        prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
        prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
        add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
        style = torch.cat([style] * 2, dim=0)

    prompt_embeds = prompt_embeds.to(device=device)
    prompt_attention_mask = prompt_attention_mask.to(device=device)
    prompt_embeds_2 = prompt_embeds_2.to(device=device)
    prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
    add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
        batch_size * num_images_per_prompt, 1
    )
    style = style.to(device=device).repeat(batch_size * num_images_per_prompt)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # once interrupted, every remaining iteration is skipped (the loop itself is not broken out of)
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
            t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
                dtype=latent_model_input.dtype
            )

            # predict the noise residual
            noise_pred = self.transformer(
                latent_model_input,
                t_expand,
                encoder_hidden_states=prompt_embeds,
                text_embedding_mask=prompt_attention_mask,
                encoder_hidden_states_t5=prompt_embeds_2,
                text_embedding_mask_t5=prompt_attention_mask_2,
                image_meta_size=add_time_ids,
                style=style,
                image_rotary_emb=image_rotary_emb,
                return_dict=False,
            )[0]

            # keep only the first half along dim 1; the second half is discarded
            # NOTE(review): presumably the model's predicted variance channels - confirm
            noise_pred, _ = noise_pred.chunk(2, dim=1)

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if self.do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                # tensors are looked up by local-variable name, so the names in this method are part of
                # the callback contract
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # the callback may hand back replacements for any of these tensors
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
                negative_prompt_embeds_2 = callback_outputs.pop(
                    "negative_prompt_embeds_2", negative_prompt_embeds_2
                )

            # advance the bar on the last step, or after warmup on every `scheduler.order`-th step
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # images flagged by the safety checker are not denormalized
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
_LazyModule, get_objects_from_module, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_i2vgen_xl"] = ["I2VGenXLPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_i2vgen_xl import I2VGenXLPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py ================================================ # Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
@dataclass
class I2VGenXLPipelineOutput(BaseOutput):
    r"""
    Output class for the image-to-video pipeline.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            The generated videos. Either a nested list of length `batch_size`, where each sub-list holds the
            denoised PIL images of one video (`num_frames` entries), or a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Sole payload of the output object; see the class docstring for the accepted layouts.
    frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
def encode_prompt(
    self,
    prompt,
    device,
    num_videos_per_prompt,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Whether unconditional (negative) embeddings are produced is decided by the
    `self.do_classifier_free_guidance` property, not by an argument.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_videos_per_prompt (`int`):
            number of videos that should be generated per prompt; embeddings are repeated this many times
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when `self.do_classifier_free_guidance` is true and
            `negative_prompt_embeds` is not given.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`. `negative_prompt_embeds` is returned exactly as it
        was passed in (possibly `None`) when `self.do_classifier_free_guidance` is false.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    # Batch size comes from the prompt when present, otherwise from the pre-computed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # second, untruncated tokenization is only used to detect and report truncation
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # only pass an attention mask when the text encoder config asks for one
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Target dtype: text encoder first, then unet, then whatever the embeddings already use.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if self.do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            # no negative prompt: use one empty string per batch entry
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # pad the negative prompt to the same sequence length as the positive embeddings
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        # Apply clip_skip to negative prompt embeds
        if clip_skip is None:
            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]
        else:
            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            negative_prompt_embeds = negative_prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            negative_prompt_embeds = self.text_encoder.text_model.final_layer_norm(negative_prompt_embeds)

    if self.do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments accepted by `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are forwarded only
    when the scheduler's `step` declares a parameter of that name.

    Args:
        generator: Value forwarded as `generator` when supported.
        eta: Value forwarded as `eta` when supported. eta (η) corresponds to η in the DDIM paper
            (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        A dict containing the supported subset of `{"eta": eta, "generator": generator}`.
    """
    step_parameter_names = set(inspect.signature(self.scheduler.step).parameters.keys())
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameter_names}
def prepare_image_latents(self, image, device, num_frames, num_videos_per_prompt):
    """
    Encode the conditioning image with the VAE and extend it along a frame axis.

    The sampled VAE latents are multiplied by `self.vae.config.scaling_factor` and given a frame axis
    at dim 2. For every further frame a constant mask is appended whose value rises linearly to 1 at
    the last frame. The result is repeated `num_videos_per_prompt` times along the batch axis and
    doubled when `self.do_classifier_free_guidance` is true.

    Args:
        image: Image tensor passed to `self.vae.encode` after being moved to `device`.
        device: Target device for `image`.
        num_frames: Total number of frames along dim 2 of the result.
        num_videos_per_prompt: Batch repetition factor.

    Returns:
        The extended image latents.
    """
    sampled = self.vae.encode(image.to(device=device)).latent_dist.sample()
    # scale, then add the frames dimension
    first_frame = (sampled * self.vae.config.scaling_factor).unsqueeze(2)

    # one constant-valued position mask per frame after the initial image latent frame
    trailing_frames = num_frames - 1
    position_masks = [
        torch.ones_like(first_frame[:, :, :1]) * (position / trailing_frames)
        for position in range(1, trailing_frames + 1)
    ]

    image_latents = first_frame
    if position_masks:
        stacked_masks = torch.cat(position_masks, dim=2)
        image_latents = torch.cat([first_frame, stacked_masks], dim=2)

    # duplicate image_latents for each generation per prompt, using mps friendly method
    image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1, 1)

    if self.do_classifier_free_guidance:
        image_latents = torch.cat([image_latents, image_latents])

    return image_latents
image-to-video generation with [`I2VGenXLPipeline`]. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide video generation. If not defined, you need to pass `prompt_embeds`. image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`): Image or images to guide video generation. If you provide a tensor, it needs to be compatible with [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). height (`int`, *optional*, defaults to 704): The height in pixels of the generated video frames. If `None`, falls back to `self.unet.config.sample_size * self.vae_scale_factor`. width (`int`, *optional*, defaults to 1280): The width in pixels of the generated video frames. If `None`, falls back to `self.unet.config.sample_size * self.vae_scale_factor`. target_fps (`int`, *optional*, defaults to 16): Frames per second. The rate at which the generated images shall be exported to a video after generation. This is also used as a "micro-condition" during generation. num_frames (`int`, *optional*, defaults to 16): The number of video frames to generate. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. guidance_scale (`float`, *optional*, defaults to 9.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in video generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`). eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. 
decode_chunk_size (`int`, *optional*, defaults to 1): The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency between frames, but also the higher the memory consumption. With the default of 1 the frames are decoded one at a time to keep memory usage low. Pass `None` to decode all frames at once. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `PIL.Image` or `np.array`, or pass `"latent"` to get the denoised latents back without decoding them. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*, defaults to 1): Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. Examples: Returns: [`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, image, height, width, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. self._guidance_scale = guidance_scale # 3.1 Encode input text prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_videos_per_prompt, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clip_skip=clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 3.2 Encode image prompt # 3.2.1 Image encodings. 
# https://github.com/ali-vilab/i2vgen-xl/blob/2539c9262ff8a2a22fa9daecbfd13f0a2dbc32d0/tools/inferences/inference_i2vgen_entrance.py#L114 cropped_image = _center_crop_wide(image, (width, width)) cropped_image = _resize_bilinear( cropped_image, (self.feature_extractor.crop_size["width"], self.feature_extractor.crop_size["height"]) ) image_embeddings = self._encode_image(cropped_image, device, num_videos_per_prompt) # 3.2.2 Image latents. resized_image = _center_crop_wide(image, (width, height)) image = self.video_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype) image_latents = self.prepare_image_latents( image, device=device, num_frames=num_frames, num_videos_per_prompt=num_videos_per_prompt, ) # 3.3 Prepare additional conditions for the UNet. if self.do_classifier_free_guidance: fps_tensor = torch.tensor([target_fps, target_fps]).to(device) else: fps_tensor = torch.tensor([target_fps]).to(device) fps_tensor = fps_tensor.repeat(batch_size * num_videos_per_prompt, 1).ravel() # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, num_frames, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, fps=fps_tensor, image_latents=image_latents, image_embeddings=image_embeddings, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # reshape latents batch_size, channel, frames, width, height = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channel, width, height) noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channel, width, height) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # reshape latents back latents = latents[None, :].reshape(batch_size, frames, channel, width, height).permute(0, 2, 1, 3, 4) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() # 8. Post processing if output_type == "latent": video = latents else: video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 9. 
def _center_crop_wide(
    image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], resolution: Tuple[int, int]
):
    """Resize so that `resolution` is covered, then crop the centre to exactly `resolution`.

    The scale factor is taken from the first image only and applied to every image, and the crop
    offsets are likewise computed from the first resized image.

    Args:
        image: A single image or a list of images. Tensors are first converted with `_convert_pt_to_pil`.
        resolution: Target size as `(width, height)`.

    Returns:
        A cropped image when a single image was given, otherwise a list of cropped images.
    """
    # First convert the images to PIL in case they are float tensors (only relevant for tests now).
    image = _convert_pt_to_pil(image)

    is_batch = isinstance(image, list)
    frames = image if is_batch else [image]
    target_width, target_height = resolution

    reference = frames[0]
    scale = min(reference.size[0] / target_width, reference.size[1] / target_height)

    def _scaled_size(frame):
        # `//` on floats can land one pixel short of the target, e.g. 1000 // (1000 / 300) == 299.0.
        # An undersized frame makes the crop box start at a negative offset, which pads the
        # result instead of cropping it, so each side is clamped to the requested resolution.
        return (
            max(target_width, round(frame.width // scale)),
            max(target_height, round(frame.height // scale)),
        )

    resized = [frame.resize(_scaled_size(frame), resample=PIL.Image.BOX) for frame in frames]

    # center crop
    x1 = (resized[0].width - target_width) // 2
    y1 = (resized[0].height - target_height) // 2
    crop_box = (x1, y1, x1 + target_width, y1 + target_height)
    cropped = [frame.crop(crop_box) for frame in resized]

    return cropped if is_batch else cropped[0]
_import_structure["text_encoder"] = ["MultilingualCLIP"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_kandinsky import KandinskyPipeline from .pipeline_kandinsky_combined import ( KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyInpaintCombinedPipeline, ) from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline from .pipeline_kandinsky_prior import KandinskyPriorPipeline, KandinskyPriorPipelineOutput from .text_encoder import MultilingualCLIP else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
def get_new_h_w(h, w, scale_factor=8):
    """Convert a pixel height/width into the size used for the latents.

    Each dimension is divided by ``scale_factor**2`` rounding up, and the resulting count is multiplied
    by ``scale_factor``.

    Args:
        h: Requested height.
        w: Requested width.
        scale_factor: Scale factor; the pipeline passes its ``movq_scale_factor``.

    Returns:
        `tuple`: ``(new_h, new_w)``.
    """
    block = scale_factor**2

    def _blocks_needed(extent):
        # Ceil-division: a partially filled block still needs a full block.
        whole, leftover = divmod(extent, block)
        return whole + 1 if leftover != 0 else whole

    return _blocks_needed(h) * scale_factor, _blocks_needed(w) * scale_factor
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
    """Return the starting latents, scaled by ``scheduler.init_noise_sigma``.

    Args:
        shape: Expected shape of the latents.
        dtype: Dtype used when fresh noise is drawn.
        device: Device for the fresh noise, or the device supplied latents are moved to.
        generator: Generator handed to `randn_tensor` when fresh noise is drawn.
        latents: Optional caller-provided latents; when `None`, noise is drawn with `randn_tensor`.
        scheduler: Object providing `init_noise_sigma`.

    Raises:
        ValueError: If `latents` is given and its shape differs from `shape`.
    """
    if latents is not None:
        if latents.shape != shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        initial = latents.to(device)
    else:
        initial = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    return initial * scheduler.init_noise_sigma
logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids.to(device) text_mask = text_inputs.attention_mask.to(device) prompt_embeds, text_encoder_hidden_states = self.text_encoder( input_ids=text_input_ids, attention_mask=text_mask ) prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=77, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) uncond_text_input_ids = uncond_input.input_ids.to(device) uncond_text_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask ) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]],
    image_embeds: Union[torch.Tensor, List[torch.Tensor]],
    negative_image_embeds: Union[torch.Tensor, List[torch.Tensor]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    height: int = 512,
    width: int = 512,
    num_inference_steps: int = 100,
    guidance_scale: float = 4.0,
    num_images_per_prompt: int = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    return_dict: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
            The clip image embeddings for text prompt, that will be used to condition the image generation. A
            list is concatenated along the batch dimension.
        negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
            The clip image embeddings for negative text prompt, will be used to condition the image generation.
            Only used when guidance is enabled.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
            ignored if `guidance_scale` is not greater than `1`).
        height (`int`, *optional*, defaults to 512):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to 512):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 100):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
            (`np.array`) or `"pt"` (`torch.Tensor`). Any other value raises a `ValueError` before generation
            starts.
        callback (`Callable`, *optional*):
            A function that calls every `callback_steps` steps during inference. The function is called with the
            following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. If not specified, the callback is called at
            every step.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

    Examples:

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`
    """

    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # Reject an unsupported output type up front instead of after the whole denoising loop has run.
    if output_type not in ["pt", "np", "pil"]:
        raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

    device = self._execution_device

    batch_size = batch_size * num_images_per_prompt
    do_classifier_free_guidance = guidance_scale > 1.0

    prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    if isinstance(image_embeds, list):
        image_embeds = torch.cat(image_embeds, dim=0)
    if isinstance(negative_image_embeds, list):
        negative_image_embeds = torch.cat(negative_image_embeds, dim=0)

    if do_classifier_free_guidance:
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

        # unconditional half first, matching the order used for the text embeddings
        image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
    elif image_embeds.shape[0] != batch_size:
        # Without guidance the embeddings used to be forwarded untouched, so they did not line up with
        # the `batch_size * num_images_per_prompt` latents. Embeddings that already have the full batch
        # size are left as they are.
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

    # Always align dtype/device with the text embeddings (previously only done when guidance was on).
    image_embeds = image_embeds.to(dtype=prompt_embeds.dtype, device=device)

    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps_tensor = self.scheduler.timesteps

    num_channels_latents = self.unet.config.in_channels

    height, width = get_new_h_w(height, width, self.movq_scale_factor)

    # create initial latent
    latents = self.prepare_latents(
        (batch_size, num_channels_latents, height, width),
        text_encoder_hidden_states.dtype,
        device,
        generator,
        latents,
        self.scheduler,
    )

    for i, t in enumerate(self.progress_bar(timesteps_tensor)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

        added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds}
        noise_pred = self.unet(
            sample=latent_model_input,
            timestep=t,
            encoder_hidden_states=text_encoder_hidden_states,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]

        if do_classifier_free_guidance:
            # The UNet output is split channel-wise into a noise part and a variance part; guidance
            # is applied to the noise part only and the text-conditioned variance is re-attached.
            noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            _, variance_pred_text = variance_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

        # Drop the variance channels unless the scheduler is configured for a learned variance.
        if not (
            hasattr(self.scheduler.config, "variance_type")
            and self.scheduler.config.variance_type in ["learned", "learned_range"]
        ):
            noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(
            noise_pred,
            t,
            latents,
            generator=generator,
        ).prev_sample

        if callback is not None and i % callback_steps == 0:
            step_idx = i // getattr(self.scheduler, "order", 1)
            callback(step_idx, t, latents)

    # post-processing
    image = self.movq.decode(latents, force_not_quantize=True)["sample"]

    self.maybe_free_model_hooks()

    if output_type in ["np", "pil"]:
        # map from [-1, 1] to [0, 1] and move channels last; "pt" returns the decoder output as is
        image = image * 0.5 + 0.5
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
# Usage example spliced into the image-to-image pipeline docstring.
# The downloaded picture is bound to `original_image`, the name that is passed to the pipeline call.
IMAGE2IMAGE_EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        from diffusers import AutoPipelineForImage2Image
        import torch
        import requests
        from io import BytesIO
        from PIL import Image
        import os

        pipe = AutoPipelineForImage2Image.from_pretrained(
            "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload()

        prompt = "A fantasy landscape, Cinematic lighting"
        negative_prompt = "low quality, bad quality"

        url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

        response = requests.get(url)
        original_image = Image.open(BytesIO(response.content)).convert("RGB")
        original_image.thumbnail((768, 768))

        image = pipe(prompt=prompt, image=original_image, num_inference_steps=25).images[0]
        ```
"""
class KandinskyCombinedPipeline(DiffusionPipeline):
    """
    Combined Pipeline for text-to-image generation using Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Internally the pipeline chains a `KandinskyPriorPipeline` (`self.prior_pipe`) and a `KandinskyPipeline`
    (`self.decoder_pipe`) that are built from the components below.

    Args:
        text_encoder ([`MultilingualCLIP`]):
            Frozen text-encoder.
        tokenizer ([`XLMRobertaTokenizer`]):
            Tokenizer of class [`XLMRobertaTokenizer`].
        scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
        prior_prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        prior_image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen image-encoder.
        prior_text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        prior_tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        prior_scheduler ([`UnCLIPScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
        prior_image_processor ([`CLIPImageProcessor`]):
            Image processor handed to the prior pipeline as its `image_processor`.
    """

    _load_connected_pipes = True
    model_cpu_offload_seq = "text_encoder->unet->movq->prior_prior->prior_image_encoder->prior_text_encoder"
    _exclude_from_cpu_offload = ["prior_prior"]

    def __init__(
        self,
        text_encoder: MultilingualCLIP,
        tokenizer: XLMRobertaTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, DDPMScheduler],
        movq: VQModel,
        prior_prior: PriorTransformer,
        prior_image_encoder: CLIPVisionModelWithProjection,
        prior_text_encoder: CLIPTextModelWithProjection,
        prior_tokenizer: CLIPTokenizer,
        prior_scheduler: UnCLIPScheduler,
        prior_image_processor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            movq=movq,
            prior_prior=prior_prior,
            prior_image_encoder=prior_image_encoder,
            prior_text_encoder=prior_text_encoder,
            prior_tokenizer=prior_tokenizer,
            prior_scheduler=prior_scheduler,
            prior_image_processor=prior_image_processor,
        )
        # Stage 1: text prompt -> image embeddings.
        self.prior_pipe = KandinskyPriorPipeline(
            prior=prior_prior,
            image_encoder=prior_image_encoder,
            text_encoder=prior_text_encoder,
            tokenizer=prior_tokenizer,
            scheduler=prior_scheduler,
            image_processor=prior_image_processor,
        )
        # Stage 2: image embeddings (+ prompt) -> image.
        self.decoder_pipe = KandinskyPipeline(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )

    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        """Enable xFormers memory-efficient attention on the decoder pipeline only."""
        self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Enable sequential CPU offloading on both sub-pipelines by forwarding `gpu_id` to
        `prior_pipe.enable_sequential_cpu_offload` and `decoder_pipe.enable_sequential_cpu_offload`.

        Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower.
        """
        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)

    def progress_bar(self, iterable=None, total=None):
        """Forward the progress-bar request to both sub-pipelines. Returns `None`."""
        self.prior_pipe.progress_bar(iterable=iterable, total=total)
        self.decoder_pipe.progress_bar(iterable=iterable, total=total)
        # NOTE: a trailing `self.decoder_pipe.enable_model_cpu_offload()` used to live here. Creating a progress
        # bar must not silently install offload hooks on the decoder, so the call was dropped.

    def set_progress_bar_config(self, **kwargs):
        """Apply the same progress-bar configuration to both sub-pipelines."""
        self.prior_pipe.set_progress_bar_config(**kwargs)
        self.decoder_pipe.set_progress_bar_config(**kwargs)

    @torch.no_grad()
    @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        height: int = 512,
        width: int = 512,
        prior_guidance_scale: float = 4.0,
        prior_num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality. Used for the prior stage.
            prior_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps of the prior stage. More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality. Used for the decoder stage.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`. The value is
                forwarded to the prior pipeline.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        # Stage 1: run the prior; with `return_dict=False` it yields (image_embeds, negative_image_embeds).
        prior_outputs = self.prior_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=prior_num_inference_steps,
            generator=generator,
            latents=latents,
            guidance_scale=prior_guidance_scale,
            output_type="pt",
            return_dict=False,
        )
        image_embeds = prior_outputs[0]
        negative_image_embeds = prior_outputs[1]

        prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt

        # Repeat the prompts so that there is one per image embedding produced by the prior.
        if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
            prompt = (image_embeds.shape[0] // len(prompt)) * prompt

        # Stage 2: decode the embeddings into images.
        outputs = self.decoder_pipe(
            prompt=prompt,
            image_embeds=image_embeds,
            negative_image_embeds=negative_image_embeds,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            generator=generator,
            guidance_scale=guidance_scale,
            output_type=output_type,
            callback=callback,
            callback_steps=callback_steps,
            return_dict=return_dict,
        )

        self.maybe_free_model_hooks()

        return outputs
class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
    """
    Combined Pipeline for image-to-image generation using Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Internally the pipeline chains a `KandinskyPriorPipeline` (`self.prior_pipe`) and a `KandinskyImg2ImgPipeline`
    (`self.decoder_pipe`) that are built from the components below.

    Args:
        text_encoder ([`MultilingualCLIP`]):
            Frozen text-encoder.
        tokenizer ([`XLMRobertaTokenizer`]):
            Tokenizer of class [`XLMRobertaTokenizer`].
        scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
        prior_prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        prior_image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen image-encoder.
        prior_text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        prior_tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        prior_scheduler ([`UnCLIPScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
        prior_image_processor ([`CLIPImageProcessor`]):
            Image processor handed to the prior pipeline as its `image_processor`.
    """

    _load_connected_pipes = True
    model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
    _exclude_from_cpu_offload = ["prior_prior"]

    def __init__(
        self,
        text_encoder: MultilingualCLIP,
        tokenizer: XLMRobertaTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, DDPMScheduler],
        movq: VQModel,
        prior_prior: PriorTransformer,
        prior_image_encoder: CLIPVisionModelWithProjection,
        prior_text_encoder: CLIPTextModelWithProjection,
        prior_tokenizer: CLIPTokenizer,
        prior_scheduler: UnCLIPScheduler,
        prior_image_processor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            movq=movq,
            prior_prior=prior_prior,
            prior_image_encoder=prior_image_encoder,
            prior_text_encoder=prior_text_encoder,
            prior_tokenizer=prior_tokenizer,
            prior_scheduler=prior_scheduler,
            prior_image_processor=prior_image_processor,
        )
        # Stage 1: text prompt -> image embeddings.
        self.prior_pipe = KandinskyPriorPipeline(
            prior=prior_prior,
            image_encoder=prior_image_encoder,
            text_encoder=prior_text_encoder,
            tokenizer=prior_tokenizer,
            scheduler=prior_scheduler,
            image_processor=prior_image_processor,
        )
        # Stage 2: image embeddings + init image -> image.
        self.decoder_pipe = KandinskyImg2ImgPipeline(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )

    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        """Enable xFormers memory-efficient attention on the decoder pipeline only."""
        self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Enable sequential CPU offloading on both sub-pipelines by forwarding `gpu_id` to
        `prior_pipe.enable_sequential_cpu_offload` and `decoder_pipe.enable_sequential_cpu_offload`.

        Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower.
        """
        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)

    def progress_bar(self, iterable=None, total=None):
        """Forward the progress-bar request to both sub-pipelines. Returns `None`."""
        self.prior_pipe.progress_bar(iterable=iterable, total=total)
        self.decoder_pipe.progress_bar(iterable=iterable, total=total)
        # NOTE: a trailing `self.decoder_pipe.enable_model_cpu_offload()` used to live here. Creating a progress
        # bar must not silently install offload hooks on the decoder, so the call was dropped.

    def set_progress_bar_config(self, **kwargs):
        """Apply the same progress-bar configuration to both sub-pipelines."""
        self.prior_pipe.set_progress_bar_config(**kwargs)
        self.decoder_pipe.set_progress_bar_config(**kwargs)

    @torch.no_grad()
    @replace_example_docstring(IMAGE2IMAGE_EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        strength: float = 0.3,
        height: int = 512,
        width: int = 512,
        prior_guidance_scale: float = 4.0,
        prior_num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded
                again.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            strength (`float`, *optional*, defaults to 0.3):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality. Used for the prior stage.
            prior_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps of the prior stage. More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality. Used for the decoder stage.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`. The value is
                forwarded to the prior pipeline.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        # Stage 1: run the prior; with `return_dict=False` it yields (image_embeds, negative_image_embeds).
        prior_outputs = self.prior_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=prior_num_inference_steps,
            generator=generator,
            latents=latents,
            guidance_scale=prior_guidance_scale,
            output_type="pt",
            return_dict=False,
        )
        image_embeds = prior_outputs[0]
        negative_image_embeds = prior_outputs[1]

        prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
        # Wrap a single PIL image so it can be repeated below. The check used to test `prompt`, which is always a
        # list at this point, so a lone image was never wrapped.
        image = [image] if isinstance(image, PIL.Image.Image) else image

        # Repeat prompts / images so that there is one per image embedding produced by the prior.
        if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
            prompt = (image_embeds.shape[0] // len(prompt)) * prompt

        if (
            isinstance(image, (list, tuple))
            and len(image) < image_embeds.shape[0]
            and image_embeds.shape[0] % len(image) == 0
        ):
            image = (image_embeds.shape[0] // len(image)) * image

        # Stage 2: decode the embeddings into images, starting from `image`.
        outputs = self.decoder_pipe(
            prompt=prompt,
            image=image,
            image_embeds=image_embeds,
            negative_image_embeds=negative_image_embeds,
            strength=strength,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            generator=generator,
            guidance_scale=guidance_scale,
            output_type=output_type,
            callback=callback,
            callback_steps=callback_steps,
            return_dict=return_dict,
        )

        self.maybe_free_model_hooks()

        return outputs
class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
    """
    Combined Pipeline for inpainting generation using Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Internally the pipeline chains a `KandinskyPriorPipeline` (`self.prior_pipe`) and a `KandinskyInpaintPipeline`
    (`self.decoder_pipe`) that are built from the components below.

    Args:
        text_encoder ([`MultilingualCLIP`]):
            Frozen text-encoder.
        tokenizer ([`XLMRobertaTokenizer`]):
            Tokenizer of class [`XLMRobertaTokenizer`].
        scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
        prior_prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        prior_image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen image-encoder.
        prior_text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        prior_tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        prior_scheduler ([`UnCLIPScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
        prior_image_processor ([`CLIPImageProcessor`]):
            Image processor handed to the prior pipeline as its `image_processor`.
    """

    _load_connected_pipes = True
    model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
    _exclude_from_cpu_offload = ["prior_prior"]

    def __init__(
        self,
        text_encoder: MultilingualCLIP,
        tokenizer: XLMRobertaTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, DDPMScheduler],
        movq: VQModel,
        prior_prior: PriorTransformer,
        prior_image_encoder: CLIPVisionModelWithProjection,
        prior_text_encoder: CLIPTextModelWithProjection,
        prior_tokenizer: CLIPTokenizer,
        prior_scheduler: UnCLIPScheduler,
        prior_image_processor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            movq=movq,
            prior_prior=prior_prior,
            prior_image_encoder=prior_image_encoder,
            prior_text_encoder=prior_text_encoder,
            prior_tokenizer=prior_tokenizer,
            prior_scheduler=prior_scheduler,
            prior_image_processor=prior_image_processor,
        )
        # Stage 1: text prompt -> image embeddings.
        self.prior_pipe = KandinskyPriorPipeline(
            prior=prior_prior,
            image_encoder=prior_image_encoder,
            text_encoder=prior_text_encoder,
            tokenizer=prior_tokenizer,
            scheduler=prior_scheduler,
            image_processor=prior_image_processor,
        )
        # Stage 2: image embeddings + image + mask -> inpainted image.
        self.decoder_pipe = KandinskyInpaintPipeline(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )

    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        """Enable xFormers memory-efficient attention on the decoder pipeline only."""
        self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Enable sequential CPU offloading on both sub-pipelines by forwarding `gpu_id` to
        `prior_pipe.enable_sequential_cpu_offload` and `decoder_pipe.enable_sequential_cpu_offload`.

        Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower.
        """
        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)

    def progress_bar(self, iterable=None, total=None):
        """Forward the progress-bar request to both sub-pipelines. Returns `None`."""
        self.prior_pipe.progress_bar(iterable=iterable, total=total)
        self.decoder_pipe.progress_bar(iterable=iterable, total=total)
        # NOTE: a trailing `self.decoder_pipe.enable_model_cpu_offload()` used to live here. Creating a progress
        # bar must not silently install offload hooks on the decoder, so the call was dropped.

    def set_progress_bar_config(self, **kwargs):
        """Apply the same progress-bar configuration to both sub-pipelines."""
        self.prior_pipe.set_progress_bar_config(**kwargs)
        self.decoder_pipe.set_progress_bar_config(**kwargs)

    @torch.no_grad()
    @replace_example_docstring(INPAINT_EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]],
        mask_image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        height: int = 512,
        width: int = 512,
        prior_guidance_scale: float = 4.0,
        prior_num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded
                again.
            mask_image (`np.array`):
                Tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while
                black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single
                channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3,
                so the expected shape would be `(B, H, W, 1)`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality. Used for the prior stage.
            prior_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps of the prior stage. More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality. Used for the decoder stage.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`. The value is
                forwarded to the prior pipeline.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        # Stage 1: run the prior; with `return_dict=False` it yields (image_embeds, negative_image_embeds).
        prior_outputs = self.prior_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=prior_num_inference_steps,
            generator=generator,
            latents=latents,
            guidance_scale=prior_guidance_scale,
            output_type="pt",
            return_dict=False,
        )
        image_embeds = prior_outputs[0]
        negative_image_embeds = prior_outputs[1]

        prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
        # Wrap a single PIL image so it can be repeated below. The check used to test `prompt`, which is always a
        # list at this point, so a lone image was never wrapped (unlike `mask_image` right below).
        image = [image] if isinstance(image, PIL.Image.Image) else image
        mask_image = [mask_image] if isinstance(mask_image, PIL.Image.Image) else mask_image

        # Repeat prompts / images / masks so that there is one per image embedding produced by the prior.
        if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
            prompt = (image_embeds.shape[0] // len(prompt)) * prompt

        if (
            isinstance(image, (list, tuple))
            and len(image) < image_embeds.shape[0]
            and image_embeds.shape[0] % len(image) == 0
        ):
            image = (image_embeds.shape[0] // len(image)) * image

        if (
            isinstance(mask_image, (list, tuple))
            and len(mask_image) < image_embeds.shape[0]
            and image_embeds.shape[0] % len(mask_image) == 0
        ):
            mask_image = (image_embeds.shape[0] // len(mask_image)) * mask_image

        # Stage 2: decode the embeddings into images, repainting the masked area of `image`.
        outputs = self.decoder_pipe(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            image_embeds=image_embeds,
            negative_image_embeds=negative_image_embeds,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            generator=generator,
            guidance_scale=guidance_scale,
            output_type=output_type,
            callback=callback,
            callback_steps=callback_steps,
            return_dict=return_dict,
        )

        self.maybe_free_model_hooks()

        return outputs
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, List, Optional, Union import numpy as np import PIL.Image import torch from PIL import Image from transformers import ( XLMRobertaTokenizer, ) from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDIMScheduler from ...utils import ( logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .text_encoder import MultilingualCLIP logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline >>> from diffusers.utils import load_image >>> import torch >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> prompt = "A red cartoon frog, 4k" >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) >>> pipe = KandinskyImg2ImgPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> init_image = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/frog.png" ... ) >>> image = pipe( ... prompt, ... image=init_image, ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=100, ... strength=0.2, ... 
def get_new_h_w(h, w, scale_factor=8):
    """
    Round `h` and `w` up to the next multiple of `scale_factor**2`, then divide by `scale_factor`.

    Args:
        h: Requested height.
        w: Requested width.
        scale_factor: Scale factor; each side is measured in blocks of `scale_factor**2`.

    Returns:
        A `(new_h, new_w)` tuple where each entry is `ceil(side / scale_factor**2) * scale_factor`.
    """
    block = scale_factor**2

    def _blocks_needed(side):
        # Negated floor division is a ceiling division: a partially filled block still counts as one.
        return -(-side // block)

    return _blocks_needed(h) * scale_factor, _blocks_needed(w) * scale_factor


def prepare_image(pil_image, w=512, h=512):
    """
    Convert a PIL image into a batched, channel-first float tensor.

    The image is resized to `(w, h)` with bicubic resampling, converted to RGB, rescaled from `[0, 255]` to
    `[-1, 1]` and returned with shape `(1, 3, h, w)`.
    """
    resized = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
    pixels = np.array(resized.convert("RGB")).astype(np.float32)
    pixels = pixels / 127.5 - 1
    # HWC -> CHW, then add the leading batch dimension.
    channels_first = np.transpose(pixels, [2, 0, 1])
    return torch.from_numpy(channels_first).unsqueeze(0)
def get_timesteps(self, num_inference_steps, strength, device):
    """Select the tail of the scheduler's timesteps that img2img will run.

    Args:
        num_inference_steps: Number of steps the scheduler was configured with.
        strength: Fraction of the schedule to run; the resulting step count is
            capped at ``num_inference_steps``.
        device: Unused; kept for interface compatibility.

    Returns:
        A tuple ``(timesteps, steps)`` holding the slice of
        ``self.scheduler.timesteps`` to iterate over and
        ``num_inference_steps`` minus the slice's start index.
    """
    # Number of denoising steps to run, never more than the full schedule.
    steps_to_run = int(num_inference_steps * strength)
    if steps_to_run > num_inference_steps:
        steps_to_run = num_inference_steps

    # Index of the first timestep to use, never negative.
    first_index = num_inference_steps - steps_to_run
    if first_index < 0:
        first_index = 0

    remaining = self.scheduler.timesteps[first_index:]
    return remaining, num_inference_steps - first_index
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids.to(device) text_mask = text_inputs.attention_mask.to(device) prompt_embeds, text_encoder_hidden_states = self.text_encoder( input_ids=text_input_ids, attention_mask=text_mask ) prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=77, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) uncond_text_input_ids = uncond_input.input_ids.to(device) uncond_text_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask ) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. 
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """Noise ``original_samples`` at the given ``timesteps``.

    This overrides the scheduler's ``add_noise`` because a different beta
    schedule is used for adding noise than for sampling: a fixed linear
    schedule of 1000 betas from 0.0001 to 0.02 is built here.

    Args:
        original_samples: Samples to be noised.
        noise: Noise tensor combined with the samples.
        timesteps: Indices into the cumulative alpha product, one per sample.

    Returns:
        ``sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise``.
    """
    schedule = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float32)
    cumulative = torch.cumprod(1.0 - schedule, dim=0)
    cumulative = cumulative.to(device=original_samples.device, dtype=original_samples.dtype)
    steps = timesteps.to(original_samples.device)

    target_rank = len(original_samples.shape)

    def _as_broadcastable(coeff):
        # Flatten to 1-D, then append singleton dims until the rank matches the samples.
        coeff = coeff.flatten()
        for _ in range(target_rank - len(coeff.shape)):
            coeff = coeff.unsqueeze(-1)
        return coeff

    signal_scale = _as_broadcastable(cumulative[steps] ** 0.5)
    noise_scale = _as_broadcastable((1 - cumulative[steps]) ** 0.5)

    return signal_scale * original_samples + noise_scale * noise
Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, return_dict: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. image (`torch.Tensor`, `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. image_embeds (`torch.Tensor` or `List[torch.Tensor]`): The clip image embeddings for text prompt, that will be used to condition the image generation. negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. strength (`float`, *optional*, defaults to 0.3): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. 
guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ # 1. Define call parameters if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") device = self._execution_device batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = guidance_scale > 1.0 # 2. 
get text and image embeddings prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) if isinstance(negative_image_embeds, list): negative_image_embeds = torch.cat(negative_image_embeds, dim=0) if do_classifier_free_guidance: image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( dtype=prompt_embeds.dtype, device=device ) # 3. pre-processing initial image if not isinstance(image, list): image = [image] if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): raise ValueError( f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" ) image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) image = image.to(dtype=prompt_embeds.dtype, device=device) latents = self.movq.encode(image)["latents"] latents = latents.repeat_interleave(num_images_per_prompt, dim=0) # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) # the formular to calculate timestep for add_noise is taken from the original kandinsky repo latent_timestep = int(self.scheduler.config.num_train_timesteps * strength) - 2 latent_timestep = torch.tensor([latent_timestep] * batch_size, dtype=timesteps_tensor.dtype, device=device) num_channels_latents = self.unet.config.in_channels height, width = get_new_h_w(height, width, self.movq_scale_factor) # 5. 
Create initial latent latents = self.prepare_latents( latents, latent_timestep, (batch_size, num_channels_latents, height, width), text_encoder_hidden_states.dtype, device, generator, self.scheduler, ) # 6. Denoising loop for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=text_encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if do_classifier_free_guidance: noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) _, variance_pred_text = variance_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) if not ( hasattr(self.scheduler.config, "variance_type") and self.scheduler.config.variance_type in ["learned", "learned_range"] ): noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, ).prev_sample if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # 7. 
post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] self.maybe_free_model_hooks() if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") if output_type in ["np", "pil"]: image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy from typing import Callable, List, Optional, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from packaging import version from PIL import Image from transformers import ( XLMRobertaTokenizer, ) from ... 
import __version__ from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDIMScheduler from ...utils import ( logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .text_encoder import MultilingualCLIP logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline >>> from diffusers.utils import load_image >>> import torch >>> import numpy as np >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> prompt = "a hat" >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) >>> pipe = KandinskyInpaintPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> init_image = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> mask = np.zeros((768, 768), dtype=np.float32) >>> mask[:250, 250:-250] = 1 >>> out = pipe( ... prompt, ... image=init_image, ... mask_image=mask, ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=50, ... 
def get_new_h_w(h, w, scale_factor=8):
    """Return the (height, width) used for the latent sample of an ``h`` x ``w`` image.

    Each dimension is divided by ``scale_factor**2`` (rounding up) and the
    quotient is then multiplied by ``scale_factor``.

    Args:
        h: Requested image height in pixels.
        w: Requested image width in pixels.
        scale_factor: Spatial scale factor; defaults to 8.

    Returns:
        A ``(new_h, new_w)`` tuple.
    """
    new_h = h // scale_factor**2
    if h % scale_factor**2 != 0:
        new_h += 1
    new_w = w // scale_factor**2
    if w % scale_factor**2 != 0:
        new_w += 1
    return new_h * scale_factor, new_w * scale_factor


def prepare_mask(masks):
    """Zero out the mask around every pixel whose channel-0 value is not 1.

    For each such pixel ``(i, j)``, the in-bounds neighbours ``(i-1, j)``,
    ``(i, j-1)``, ``(i-1, j-1)``, ``(i+1, j)``, ``(i, j+1)`` and
    ``(i+1, j+1)`` are set to 0 in all channels. The two anti-diagonal
    neighbours are deliberately left alone, matching the original per-pixel
    loop. Decisions are based on the input values only, so zeroing does not
    cascade.

    The computation is vectorized and does not modify ``masks``; the previous
    implementation looped over every pixel in Python and wrote into the
    caller's tensor through views.

    Args:
        masks: A ``(B, C, H, W)`` tensor, or an iterable of ``(C, H, W)``
            tensors which is stacked along a new batch dimension.

    Returns:
        A new ``(B, C, H, W)`` tensor with the same dtype as the input.
    """
    if not isinstance(masks, torch.Tensor):
        masks = torch.stack(list(masks), dim=0)

    # Pixels that trigger zeroing of their neighbours, decided on channel 0.
    source = masks[:, 0] != 1  # (B, H, W)

    # kill[b, y, x] is True when some source pixel has (y, x) as one of its
    # six affected neighbours. Slicing handles the image borders.
    kill = torch.zeros_like(source)
    kill[:, :-1, :] |= source[:, 1:, :]  # (i - 1, j)
    kill[:, :, :-1] |= source[:, :, 1:]  # (i, j - 1)
    kill[:, :-1, :-1] |= source[:, 1:, 1:]  # (i - 1, j - 1)
    kill[:, 1:, :] |= source[:, :-1, :]  # (i + 1, j)
    kill[:, :, 1:] |= source[:, :, :-1]  # (i, j + 1)
    kill[:, 1:, 1:] |= source[:, :-1, :-1]  # (i + 1, j + 1)

    # Broadcast the decision over the channel dimension; returns a new tensor.
    return masks.masked_fill(kill.unsqueeze(1), 0)
width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. Raises: ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (ot the other way around). Returns: tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4 dimensions: ``batch x channels x height x width``. """ if image is None: raise ValueError("`image` input cannot be undefined.") if mask is None: raise ValueError("`mask_image` input cannot be undefined.") if isinstance(image, torch.Tensor): if not isinstance(mask, torch.Tensor): raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") # Batch single image if image.ndim == 3: assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" image = image.unsqueeze(0) # Batch and add channel dim for single mask if mask.ndim == 2: mask = mask.unsqueeze(0).unsqueeze(0) # Batch single mask or add channel dim if mask.ndim == 3: # Single batched mask, no channel dim or single mask not batched but channel dim if mask.shape[0] == 1: mask = mask.unsqueeze(0) # Batched masks no channel dim else: mask = mask.unsqueeze(1) assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" # Check image is in [-1, 1] if image.min() < -1 or image.max() > 1: raise ValueError("Image should be in [-1, 1] range") # Check mask is in [0, 1] if mask.min() < 0 or mask.max() > 1: raise ValueError("Mask should be in [0, 1] range") # Binarize mask mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 # Image as float32 image = image.to(dtype=torch.float32) elif isinstance(mask, 
torch.Tensor): raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") else: # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): # resize all images w.r.t passed height an width image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image] image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): image = np.concatenate([i[None, :] for i in image], axis=0) image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 # preprocess mask if isinstance(mask, (PIL.Image.Image, np.ndarray)): mask = [mask] if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) mask = mask.astype(np.float32) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): mask = np.concatenate([m[None, None, :] for m in mask], axis=0) mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 mask = torch.from_numpy(mask) mask = 1 - mask return mask, image class KandinskyInpaintPipeline(DiffusionPipeline): """ Pipeline for text-guided image inpainting using Kandinsky2.1 This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: text_encoder ([`MultilingualCLIP`]): Frozen text-encoder. tokenizer ([`XLMRobertaTokenizer`]): Tokenizer of class scheduler ([`DDIMScheduler`]): A scheduler to be used in combination with `unet` to generate image latents. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the image embedding. 
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
    """Create or validate the starting latents and scale them for the scheduler.

    Args:
        shape: Expected shape of the latents.
        dtype: Dtype used when fresh latents are sampled.
        device: Device the latents are created on or moved to.
        generator: Random generator(s) forwarded to ``randn_tensor``.
        latents: Optional pre-made latents; when ``None``, new ones are sampled.
        scheduler: Scheduler providing ``init_noise_sigma``.

    Returns:
        The latents multiplied by ``scheduler.init_noise_sigma``.

    Raises:
        ValueError: If ``latents`` is given and its shape differs from ``shape``.
    """
    if latents is None:
        initial = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    elif latents.shape != shape:
        raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
    else:
        initial = latents.to(device)

    return initial * scheduler.init_noise_sigma
text_encoder_hidden_states = self.text_encoder( input_ids=text_input_ids, attention_mask=text_mask ) prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=77, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) uncond_text_input_ids = uncond_input.input_ids.to(device) uncond_text_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask ) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) return prompt_embeds, text_encoder_hidden_states, text_mask @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]], image: Union[torch.Tensor, PIL.Image.Image], mask_image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], image_embeds: torch.Tensor, negative_image_embeds: torch.Tensor, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 512, num_inference_steps: int = 100, guidance_scale: float = 4.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, return_dict: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. image (`torch.Tensor`, `PIL.Image.Image` or `np.ndarray`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. mask_image (`PIL.Image.Image`,`torch.Tensor` or `np.ndarray`): `Image`, or a tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while black pixels will be preserved. 
You can pass a pytorch tensor as mask only if the image you passed is a pytorch tensor, and it should contain one color channel (L) instead of 3, so the expected shape would be either `(B, 1, H, W,)`, `(B, H, W)`, `(1, H, W)` or `(H, W)` If image is an PIL image or numpy array, mask should also be a either PIL image or numpy array. If it is a PIL image, it will be converted to a single channel (luminance) before use. If it is a nummpy array, the expected shape is `(H, W)`. image_embeds (`torch.Tensor` or `List[torch.Tensor]`): The clip image embeddings for text prompt, that will be used to condition the image generation. negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ if not self._warn_has_been_called and version.parse(version.parse(__version__).base_version) < version.parse( "0.23.0.dev0" ): logger.warning( "Please note that the expected format of `mask_image` has recently been changed. " "Before diffusers == 0.19.0, Kandinsky Inpainting pipelines repainted black pixels and preserved black pixels. " "As of diffusers==0.19.0 this behavior has been inverted. Now white pixels are repainted and black pixels are preserved. " "This way, Kandinsky's masking behavior is aligned with Stable Diffusion. " "THIS means that you HAVE to invert the input mask to have the same behavior as before as explained in https://github.com/huggingface/diffusers/pull/4207. 
" "This warning will be surpressed after the first inference call and will be removed in diffusers>0.23.0" ) self._warn_has_been_called = True # Define call parameters if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") device = self._execution_device batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) if isinstance(negative_image_embeds, list): negative_image_embeds = torch.cat(negative_image_embeds, dim=0) if do_classifier_free_guidance: image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( dtype=prompt_embeds.dtype, device=device ) # preprocess image and mask mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) image = image.to(dtype=prompt_embeds.dtype, device=device) image = self.movq.encode(image)["latents"] mask_image = mask_image.to(dtype=prompt_embeds.dtype, device=device) image_shape = tuple(image.shape[-2:]) mask_image = F.interpolate( mask_image, image_shape, mode="nearest", ) mask_image = prepare_mask(mask_image) masked_image = image * mask_image mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: mask_image = mask_image.repeat(2, 1, 1, 1) masked_image = masked_image.repeat(2, 1, 1, 1) self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps_tensor = self.scheduler.timesteps 
num_channels_latents = self.movq.config.latent_channels # get h, w for latents sample_height, sample_width = get_new_h_w(height, width, self.movq_scale_factor) # create initial latent latents = self.prepare_latents( (batch_size, num_channels_latents, sample_height, sample_width), text_encoder_hidden_states.dtype, device, generator, latents, self.scheduler, ) # Check that sizes of mask, masked image and latents match with expected num_channels_mask = mask_image.shape[1] num_channels_masked_image = masked_image.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." 
) for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1) added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=text_encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if do_classifier_free_guidance: noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) _, variance_pred_text = variance_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) if not ( hasattr(self.scheduler.config, "variance_type") and self.scheduler.config.variance_type in ["learned", "learned_range"] ): noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, ).prev_sample if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] self.maybe_free_model_hooks() if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") if output_type in ["np", "pil"]: image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 
src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import List, Optional, Union import numpy as np import PIL.Image import torch from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer from ...schedulers import UnCLIPScheduler from ...utils import ( BaseOutput, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline >>> import torch >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior") >>> pipe_prior.to("cuda") >>> prompt = "red cat, 4k photo" >>> out = pipe_prior(prompt) >>> image_emb = out.image_embeds >>> negative_image_emb = out.negative_image_embeds >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1") >>> pipe.to("cuda") >>> image = pipe( ... prompt, ... image_embeds=image_emb, ... negative_image_embeds=negative_image_emb, ... height=768, ... width=768, ... num_inference_steps=100, ... 
@dataclass
class KandinskyPriorPipelineOutput(BaseOutput):
    """
    Output class for KandinskyPriorPipeline.

    Args:
        image_embeds (`torch.Tensor` or `np.ndarray`)
            clip image embeddings for text prompt
        negative_image_embeds (`torch.Tensor` or `np.ndarray`)
            clip image embeddings for unconditional tokens
    """

    image_embeds: Union[torch.Tensor, np.ndarray]
    negative_image_embeds: Union[torch.Tensor, np.ndarray]


class KandinskyPriorPipeline(DiffusionPipeline):
    """
    Pipeline for generating image prior for Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen image-encoder.
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        scheduler ([`UnCLIPScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
        image_processor ([`CLIPImageProcessor`]):
            Processor used by `interpolate` to turn PIL images into pixel tensors for `image_encoder`.
    """

    _exclude_from_cpu_offload = ["prior"]
    model_cpu_offload_seq = "text_encoder->prior"

    def __init__(
        self,
        prior: PriorTransformer,
        image_encoder: CLIPVisionModelWithProjection,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        scheduler: UnCLIPScheduler,
        image_processor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            prior=prior,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            image_processor=image_processor,
        )

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
    def interpolate(
        self,
        images_and_prompts: List[Union[str, PIL.Image.Image, torch.Tensor]],
        weights: List[float],
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        negative_prior_prompt: Optional[str] = None,
        negative_prompt: str = "",
        guidance_scale: float = 4.0,
        device=None,
    ):
        """
        Function invoked when using the prior pipeline for interpolation.

        Args:
            images_and_prompts (`List[Union[str, PIL.Image.Image, torch.Tensor]]`):
                list of prompts and images to guide the image generation.
            weights (`List[float]`):
                list of weights for each condition in `images_and_prompts`
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            negative_prior_prompt (`str`, *optional*):
                The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is less than `1`).
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is less than `1`).
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            device (*optional*):
                Device on which image conditions are encoded. Falls back to `self.device` when not given.

        Examples:

        Returns:
            [`KandinskyPriorPipelineOutput`]
        """

        device = device or self.device

        if len(images_and_prompts) != len(weights):
            raise ValueError(
                f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length"
            )

        image_embeddings = []
        for cond, weight in zip(images_and_prompts, weights):
            if isinstance(cond, str):
                # Text conditions are turned into image embeddings by running the prior itself.
                image_emb = self(
                    cond,
                    num_inference_steps=num_inference_steps,
                    num_images_per_prompt=num_images_per_prompt,
                    generator=generator,
                    latents=latents,
                    negative_prompt=negative_prior_prompt,
                    guidance_scale=guidance_scale,
                ).image_embeds

            elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
                if isinstance(cond, PIL.Image.Image):
                    cond = self.image_processor(cond, return_tensors="pt").pixel_values[0].unsqueeze(0)

                # Caller-supplied tensors get the same dtype/device handling as processed PIL images, so a
                # CPU or float32 tensor no longer fails against a GPU or half-precision image encoder.
                cond = cond.to(dtype=self.image_encoder.dtype, device=device)

                image_emb = self.image_encoder(cond)["image_embeds"]

            else:
                raise ValueError(
                    f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor`  but is {type(cond)}"
                )

            image_embeddings.append(image_emb * weight)

        # Weighted sum of all condition embeddings, kept as a single-row batch.
        image_emb = torch.cat(image_embeddings).sum(dim=0, keepdim=True)

        out_zero = self(
            negative_prompt,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_images_per_prompt,
            generator=generator,
            latents=latents,
            negative_prompt=negative_prior_prompt,
            guidance_scale=guidance_scale,
        )
        # With the default empty negative prompt the zero-image embedding is used; otherwise the embedding
        # predicted for the negative prompt itself.
        zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds

        return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb)

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def get_zero_embed(self, batch_size=1, device=None):
        """
        Encode an all-zero image with `image_encoder` and repeat the resulting embedding `batch_size` times.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                Number of times the embedding is repeated along the first dimension.
            device (*optional*):
                Device for the zero image. Falls back to `self.device` when not given.

        Returns:
            `torch.Tensor`: the repeated image embedding of the zero image.
        """
        device = device or self.device
        zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to(
            device=device, dtype=self.image_encoder.dtype
        )
        zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
        zero_image_emb = zero_image_emb.repeat(batch_size, 1)
        return zero_image_emb

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
    ):
        """
        Tokenize and encode `prompt` (and, under classifier-free guidance, the unconditional prompt).

        Returns:
            `tuple`: `(prompt_embeds, text_encoder_hidden_states, text_mask)`. Each is repeated
            `num_images_per_prompt` times per prompt; when `do_classifier_free_guidance` is set, the unconditional
            entries are concatenated in front of the conditional ones.

        Raises:
            TypeError: if `negative_prompt` is given and its type differs from the type of `prompt`.
            ValueError: if `negative_prompt` is a list whose length differs from the prompt batch size.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        text_mask = text_inputs.attention_mask.bool().to(device)

        # Tokenize again without truncation only to report what was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

        text_encoder_output = self.text_encoder(text_input_ids.to(device))

        prompt_embeds = text_encoder_output.text_embeds
        text_encoder_hidden_states = text_encoder_output.last_hidden_state

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_encoder_hidden_states, text_mask

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        guidance_scale: float = 4.0,
        output_type: Optional[str] = "pt",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            output_type (`str`, *optional*, defaults to `"pt"`):
                The output format of the generated embeddings. Choose between: `"np"` (`np.array`) or `"pt"`
                (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`KandinskyPriorPipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`KandinskyPriorPipelineOutput`] or `tuple`
        """

        # Reject an unsupported output type up front instead of after the whole denoising loop has run.
        if output_type not in ["pt", "np"]:
            raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")

        if isinstance(prompt, str):
            prompt = [prompt]
        elif not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]
        elif not isinstance(negative_prompt, list) and negative_prompt is not None:
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

        # if the negative prompt is defined we double the batch size to
        # directly retrieve the negative prompt embedding
        if negative_prompt is not None:
            prompt = prompt + negative_prompt
            negative_prompt = 2 * negative_prompt

        device = self._execution_device

        batch_size = len(prompt) * num_images_per_prompt

        do_classifier_free_guidance = guidance_scale > 1.0
        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # prior
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        prior_timesteps_tensor = self.scheduler.timesteps

        embedding_dim = self.prior.config.embedding_dim

        latents = self.prepare_latents(
            (batch_size, embedding_dim),
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            predicted_image_embedding = self.prior(
                latent_model_input,
                timestep=t,
                proj_embedding=prompt_embeds,
                encoder_hidden_states=text_encoder_hidden_states,
                attention_mask=text_mask,
            ).predicted_image_embedding

            if do_classifier_free_guidance:
                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
                predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
                    predicted_image_embedding_text - predicted_image_embedding_uncond
                )

            # The scheduler is told the next timestep explicitly; there is none after the last step.
            if i + 1 == prior_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = prior_timesteps_tensor[i + 1]

            latents = self.scheduler.step(
                predicted_image_embedding,
                timestep=t,
                sample=latents,
                generator=generator,
                prev_timestep=prev_timestep,
            ).prev_sample

        latents = self.prior.post_process_latents(latents)

        image_embeddings = latents

        # if negative prompt has been defined, we split the image embedding into two
        if negative_prompt is None:
            zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)

            self.maybe_free_model_hooks()
        else:
            image_embeddings, zero_embeds = image_embeddings.chunk(2)

            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
                self.prior_hook.offload()

        if output_type == "np":
            image_embeddings = image_embeddings.cpu().numpy()
            zero_embeds = zero_embeds.cpu().numpy()

        if not return_dict:
            return (image_embeddings, zero_embeds)

        return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
class MCLIPConfig(XLMRobertaConfig):
    """XLM-RoBERTa configuration carrying the two extra sizes used by `MultilingualCLIP`.

    `transformerDimSize` is stored as `transformerDimensions` and `imageDimSize` as `numDims`; all remaining
    keyword arguments are forwarded to `XLMRobertaConfig`.
    """

    model_type = "M-CLIP"

    def __init__(self, transformerDimSize=1024, imageDimSize=768, **kwargs):
        # The extra attributes are assigned before the parent initializer runs, as in the original.
        self.transformerDimensions = transformerDimSize
        self.numDims = imageDimSize
        super().__init__(**kwargs)


class MultilingualCLIP(PreTrainedModel):
    """XLM-RoBERTa encoder followed by a linear projection of its mask-averaged token states."""

    config_class = MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.transformer = XLMRobertaModel(config)
        # Projects from `config.transformerDimensions` features to `config.numDims` features.
        self.LinearTransformation = torch.nn.Linear(
            in_features=config.transformerDimensions,
            out_features=config.numDims,
        )

    def forward(self, input_ids, attention_mask):
        """Return `(projected, token_states)`.

        `token_states` is the first output of the transformer; `projected` is the linear projection of those
        states averaged over the positions selected by `attention_mask`.
        """
        token_states = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]

        # Sum the states of attended positions, then divide by how many positions were attended per row.
        masked_sum = (token_states * attention_mask.unsqueeze(2)).sum(dim=1)
        attended_count = attention_mask.sum(dim=1)[:, None]
        pooled = masked_sum / attended_count

        return self.LinearTransformation(pooled), token_states
_import_structure["pipeline_kandinsky2_2_controlnet"] = ["KandinskyV22ControlnetPipeline"] _import_structure["pipeline_kandinsky2_2_controlnet_img2img"] = ["KandinskyV22ControlnetImg2ImgPipeline"] _import_structure["pipeline_kandinsky2_2_img2img"] = ["KandinskyV22Img2ImgPipeline"] _import_structure["pipeline_kandinsky2_2_inpainting"] = ["KandinskyV22InpaintPipeline"] _import_structure["pipeline_kandinsky2_2_prior"] = ["KandinskyV22PriorPipeline"] _import_structure["pipeline_kandinsky2_2_prior_emb2emb"] = ["KandinskyV22PriorEmb2EmbPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_kandinsky2_2 import KandinskyV22Pipeline from .pipeline_kandinsky2_2_combined import ( KandinskyV22CombinedPipeline, KandinskyV22Img2ImgCombinedPipeline, KandinskyV22InpaintCombinedPipeline, ) from .pipeline_kandinsky2_2_controlnet import KandinskyV22ControlnetPipeline from .pipeline_kandinsky2_2_controlnet_img2img import KandinskyV22ControlnetImg2ImgPipeline from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline from .pipeline_kandinsky2_2_inpainting import KandinskyV22InpaintPipeline from .pipeline_kandinsky2_2_prior import KandinskyV22PriorPipeline from .pipeline_kandinsky2_2_prior_emb2emb import KandinskyV22PriorEmb2EmbPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, Dict, List, Optional, Union import torch from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDPMScheduler from ...utils import deprecate, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline >>> import torch >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior") >>> pipe_prior.to("cuda") >>> prompt = "red cat, 4k photo" >>> out = pipe_prior(prompt) >>> image_emb = out.image_embeds >>> zero_image_emb = out.negative_image_embeds >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=50, ... 
def downscale_height_and_width(height, width, scale_factor=8):
    """Return the latent `(height, width)` for a requested pixel size.

    Each dimension is divided by `scale_factor**2`, rounded up, and then multiplied by `scale_factor`.
    """
    divisor = scale_factor**2

    def _round_up_units(extent):
        # Ceil-division: one extra unit whenever `extent` is not an exact multiple of the divisor.
        whole, leftover = divmod(extent, divisor)
        return whole + 1 if leftover != 0 else whole

    return _round_up_units(height) * scale_factor, _round_up_units(width) * scale_factor
num_timesteps(self): return self._num_timesteps @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, image_embeds: Union[torch.Tensor, List[torch.Tensor]], negative_image_embeds: Union[torch.Tensor, List[torch.Tensor]], height: int = 512, width: int = 512, num_inference_steps: int = 100, guidance_scale: float = 4.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): """ Function invoked when calling the pipeline for generation. Args: image_embeds (`torch.Tensor` or `List[torch.Tensor]`): The clip image embeddings for text prompt, that will be used to condition the image generation. negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. 
num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. 
Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) device = self._execution_device self._guidance_scale = guidance_scale if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) batch_size = image_embeds.shape[0] * num_images_per_prompt if isinstance(negative_image_embeds, list): negative_image_embeds = torch.cat(negative_image_embeds, dim=0) if self.do_classifier_free_guidance: image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( dtype=self.unet.dtype, device=device ) self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps num_channels_latents = self.unet.config.in_channels height, width = downscale_height_and_width(height, width, self.movq_scale_factor) # create initial latent latents = self.prepare_latents( (batch_size, num_channels_latents, height, width), image_embeds.dtype, device, generator, latents, self.scheduler, ) self._num_timesteps = len(timesteps) for i, t in 
enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents added_cond_kwargs = {"image_embeds": image_embeds} noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=None, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if self.do_classifier_free_guidance: noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) _, variance_pred_text = variance_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) if not ( hasattr(self.scheduler.config, "variance_type") and self.scheduler.config.variance_type in ["learned", "learned_range"] ): noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, )[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) image_embeds = callback_outputs.pop("image_embeds", image_embeds) negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds) if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") if not output_type == "latent": # post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] if output_type in ["np", "pil"]: image = image * 0.5 + 0.5 image = 
image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) else: image = latents self.maybe_free_model_hooks() if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Callable, Dict, List, Optional, Union

import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection

from ...models import PriorTransformer, UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler, UnCLIPScheduler
from ...utils import deprecate, logging, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from .pipeline_kandinsky2_2 import KandinskyV22Pipeline
from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline
from .pipeline_kandinsky2_2_inpainting import KandinskyV22InpaintPipeline
from .pipeline_kandinsky2_2_prior import KandinskyV22PriorPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

TEXT2IMAGE_EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipe = AutoPipelineForText2Image.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload()

        prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"

        image = pipe(prompt=prompt, num_inference_steps=25).images[0]
        ```
"""

# The downloaded picture is bound to `original_image`, the name passed to `pipe(...)` below.
IMAGE2IMAGE_EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        from diffusers import AutoPipelineForImage2Image
        import torch
        import requests
        from io import BytesIO
        from PIL import Image
        import os

        pipe = AutoPipelineForImage2Image.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload()

        prompt = "A fantasy landscape, Cinematic lighting"
        negative_prompt = "low quality, bad quality"

        url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

        response = requests.get(url)
        original_image = Image.open(BytesIO(response.content)).convert("RGB")
        original_image.thumbnail((768, 768))

        image = pipe(prompt=prompt, image=original_image, num_inference_steps=25).images[0]
        ```
"""

INPAINT_EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        from diffusers import AutoPipelineForInpainting
        from diffusers.utils import load_image
        import torch
        import numpy as np

        pipe = AutoPipelineForInpainting.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload()

        prompt = "A fantasy landscape, Cinematic lighting"
        negative_prompt = "low quality, bad quality"

        original_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
        )

        mask = np.zeros((768, 768), dtype=np.float32)
        # Let's mask out an area above the cat's head
        mask[:250, 250:-250] = 1

        image = pipe(prompt=prompt, image=original_image, mask_image=mask, num_inference_steps=25).images[0]
        ```
"""


class KandinskyV22CombinedPipeline(DiffusionPipeline):
    """
    Combined Pipeline for text-to-image generation using Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
        prior_prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        prior_image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen image-encoder.
        prior_text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        prior_tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        prior_scheduler ([`UnCLIPScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
        prior_image_processor ([`CLIPImageProcessor`]):
            A image_processor to be used to preprocess image from clip.
    """

    model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq"
    _load_connected_pipes = True
    _exclude_from_cpu_offload = ["prior_prior"]

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        movq: VQModel,
        prior_prior: PriorTransformer,
        prior_image_encoder: CLIPVisionModelWithProjection,
        prior_text_encoder: CLIPTextModelWithProjection,
        prior_tokenizer: CLIPTokenizer,
        prior_scheduler: UnCLIPScheduler,
        prior_image_processor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
            prior_prior=prior_prior,
            prior_image_encoder=prior_image_encoder,
            prior_text_encoder=prior_text_encoder,
            prior_tokenizer=prior_tokenizer,
            prior_scheduler=prior_scheduler,
            prior_image_processor=prior_image_processor,
        )
        # The two stages are wrapped as sub-pipelines that share the registered modules.
        self.prior_pipe = KandinskyV22PriorPipeline(
            prior=prior_prior,
            image_encoder=prior_image_encoder,
            text_encoder=prior_text_encoder,
            tokenizer=prior_tokenizer,
            scheduler=prior_scheduler,
            image_processor=prior_image_processor,
        )
        self.decoder_pipe = KandinskyV22Pipeline(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )

    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        """Enable xformers memory-efficient attention on the decoder sub-pipeline."""
        self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. The call is forwarded to
        both the prior and the decoder sub-pipelines. When called, the models have their state dicts saved to CPU and
        then are moved to a `torch.device('meta')` and loaded to GPU only when their specific submodule has its
        `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than
        with `enable_model_cpu_offload`, but performance is lower.
        """
        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)

    def progress_bar(self, iterable=None, total=None):
        """
        Forward a progress-bar request to both sub-pipelines.

        The sub-pipelines' return values are discarded, so this method returns `None`.
        """
        self.prior_pipe.progress_bar(iterable=iterable, total=total)
        self.decoder_pipe.progress_bar(iterable=iterable, total=total)
        # NOTE: this used to also call `self.decoder_pipe.enable_model_cpu_offload()`. Asking for a
        # progress bar must not change where the models live, so that side effect was removed.

    def set_progress_bar_config(self, **kwargs):
        """Apply the same progress-bar configuration to both sub-pipelines."""
        self.prior_pipe.set_progress_bar_config(**kwargs)
        self.decoder_pipe.set_progress_bar_config(**kwargs)

    @torch.no_grad()
    @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        height: int = 512,
        width: int = 512,
        prior_guidance_scale: float = 4.0,
        prior_num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        return_dict: bool = True,
        prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            prior_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            callback (`Callable`, *optional*):
                Legacy per-step callback forwarded to the decoder pipeline, called as `callback(step: int, timestep:
                int, latents: torch.Tensor)`. Prefer `callback_on_step_end`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the legacy `callback` is called. Only forwarded when `callback` is given.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            prior_callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference of the prior pipeline.
                The function is called with the following arguments: `prior_callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
            prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
                list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
                the `._callback_tensor_inputs` attribute of your prior pipeline class.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference of the decoder pipeline.
                The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline,
                step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors
                as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        # Stage 1: the prior turns the prompts into positive and negative image embeddings.
        prior_outputs = self.prior_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=prior_num_inference_steps,
            generator=generator,
            latents=latents,
            guidance_scale=prior_guidance_scale,
            output_type="pt",
            return_dict=False,
            callback_on_step_end=prior_callback_on_step_end,
            callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
        )
        image_embeds = prior_outputs[0]
        negative_image_embeds = prior_outputs[1]

        # Only forward the legacy arguments when a legacy callback is actually supplied. The decoder
        # emits a deprecation notice for any non-None `callback_steps`, so always passing the default
        # would trigger it on every call.
        legacy_callback_kwargs = {}
        if callback is not None:
            legacy_callback_kwargs = {"callback": callback, "callback_steps": callback_steps}

        # Stage 2: the decoder generates the images from the embeddings.
        outputs = self.decoder_pipe(
            image_embeds=image_embeds,
            negative_image_embeds=negative_image_embeds,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            generator=generator,
            guidance_scale=guidance_scale,
            output_type=output_type,
            return_dict=return_dict,
            callback_on_step_end=callback_on_step_end,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            **legacy_callback_kwargs,
        )
        self.maybe_free_model_hooks()

        return outputs
class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
    """
    Combined Pipeline for image-to-image generation using Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
        prior_prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        prior_image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen image-encoder.
        prior_text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        prior_tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        prior_scheduler ([`UnCLIPScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
        prior_image_processor ([`CLIPImageProcessor`]):
            A image_processor to be used to preprocess image from clip.
    """

    model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq"
    _load_connected_pipes = True
    _exclude_from_cpu_offload = ["prior_prior"]

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        movq: VQModel,
        prior_prior: PriorTransformer,
        prior_image_encoder: CLIPVisionModelWithProjection,
        prior_text_encoder: CLIPTextModelWithProjection,
        prior_tokenizer: CLIPTokenizer,
        prior_scheduler: UnCLIPScheduler,
        prior_image_processor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
            prior_prior=prior_prior,
            prior_image_encoder=prior_image_encoder,
            prior_text_encoder=prior_text_encoder,
            prior_tokenizer=prior_tokenizer,
            prior_scheduler=prior_scheduler,
            prior_image_processor=prior_image_processor,
        )
        # The two stages are wrapped as sub-pipelines that share the registered modules.
        self.prior_pipe = KandinskyV22PriorPipeline(
            prior=prior_prior,
            image_encoder=prior_image_encoder,
            text_encoder=prior_text_encoder,
            tokenizer=prior_tokenizer,
            scheduler=prior_scheduler,
            image_processor=prior_image_processor,
        )
        self.decoder_pipe = KandinskyV22Img2ImgPipeline(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )

    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        """Enable xformers memory-efficient attention on the decoder sub-pipeline."""
        self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. The call
        is forwarded to both the prior and the decoder sub-pipelines. Compared to `enable_sequential_cpu_offload`,
        this method moves one whole model at a time to the GPU when its `forward` method is called, and the model
        remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`,
        but performance is much better due to the iterative execution of the `unet`.
        """
        self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
        self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)

    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. The call is forwarded to
        both the prior and the decoder sub-pipelines. When called, the models have their state dicts saved to CPU and
        then are moved to a `torch.device('meta')` and loaded to GPU only when their specific submodule has its
        `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than
        with `enable_model_cpu_offload`, but performance is lower.
        """
        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)

    def progress_bar(self, iterable=None, total=None):
        """
        Forward a progress-bar request to both sub-pipelines.

        The sub-pipelines' return values are discarded, so this method returns `None`.
        """
        self.prior_pipe.progress_bar(iterable=iterable, total=total)
        self.decoder_pipe.progress_bar(iterable=iterable, total=total)
        # NOTE: this used to also call `self.decoder_pipe.enable_model_cpu_offload()`. Asking for a
        # progress bar must not change where the models live, so that side effect was removed.

    def set_progress_bar_config(self, **kwargs):
        """Apply the same progress-bar configuration to both sub-pipelines."""
        self.prior_pipe.set_progress_bar_config(**kwargs)
        self.decoder_pipe.set_progress_bar_config(**kwargs)

    @torch.no_grad()
    @replace_example_docstring(IMAGE2IMAGE_EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        strength: float = 0.3,
        num_images_per_prompt: int = 1,
        height: int = 512,
        width: int = 512,
        prior_guidance_scale: float = 4.0,
        prior_num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        return_dict: bool = True,
        prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded
                again.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            strength (`float`, *optional*, defaults to 0.3):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            prior_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step. Only forwarded to the decoder when `callback` is given.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            prior_callback_on_step_end (`Callable`, *optional*):
                A function forwarded to the prior pipeline as its `callback_on_step_end`.
            prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs forwarded to the prior pipeline as its
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end (`Callable`, *optional*):
                A function forwarded to the decoder pipeline as its `callback_on_step_end`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs forwarded to the decoder pipeline as its
                `callback_on_step_end_tensor_inputs`.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        # Stage 1: the prior turns the prompts into positive and negative image embeddings.
        prior_outputs = self.prior_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=prior_num_inference_steps,
            generator=generator,
            latents=latents,
            guidance_scale=prior_guidance_scale,
            output_type="pt",
            return_dict=False,
            callback_on_step_end=prior_callback_on_step_end,
            callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
        )
        image_embeds = prior_outputs[0]
        negative_image_embeds = prior_outputs[1]

        # Wrap a single PIL image so it can be tiled to the embedding batch below. The previous
        # check tested `prompt` instead of `image`, so this wrapping never happened.
        if isinstance(image, PIL.Image.Image):
            image = [image]

        num_embeds = image_embeds.shape[0]
        if isinstance(image, (list, tuple)) and len(image) < num_embeds and num_embeds % len(image) == 0:
            # Repeat the image list so there is one start image per embedding row.
            image = (num_embeds // len(image)) * image

        # Only forward the legacy arguments when a legacy callback is actually supplied, so the
        # decoder is not handed a default `callback_steps` on every call.
        # NOTE(review): assumes the img2img decoder treats these as optional - confirm.
        legacy_callback_kwargs = {}
        if callback is not None:
            legacy_callback_kwargs = {"callback": callback, "callback_steps": callback_steps}

        # Stage 2: the decoder generates the images from the start image(s) and the embeddings.
        outputs = self.decoder_pipe(
            image=image,
            image_embeds=image_embeds,
            negative_image_embeds=negative_image_embeds,
            width=width,
            height=height,
            strength=strength,
            num_inference_steps=num_inference_steps,
            generator=generator,
            guidance_scale=guidance_scale,
            output_type=output_type,
            return_dict=return_dict,
            callback_on_step_end=callback_on_step_end,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            **legacy_callback_kwargs,
        )

        self.maybe_free_model_hooks()
        return outputs
Args: scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]): A scheduler to be used in combination with `unet` to generate image latents. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the image embedding. movq ([`VQModel`]): MoVQ Decoder to generate the image from the latents. prior_prior ([`PriorTransformer`]): The canonical unCLIP prior to approximate the image embedding from the text embedding. prior_image_encoder ([`CLIPVisionModelWithProjection`]): Frozen image-encoder. prior_text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder. prior_tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). prior_scheduler ([`UnCLIPScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. prior_image_processor ([`CLIPImageProcessor`]): A image_processor to be used to preprocess image from clip. """ model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq" _load_connected_pipes = True _exclude_from_cpu_offload = ["prior_prior"] def __init__( self, unet: UNet2DConditionModel, scheduler: DDPMScheduler, movq: VQModel, prior_prior: PriorTransformer, prior_image_encoder: CLIPVisionModelWithProjection, prior_text_encoder: CLIPTextModelWithProjection, prior_tokenizer: CLIPTokenizer, prior_scheduler: UnCLIPScheduler, prior_image_processor: CLIPImageProcessor, ): super().__init__() self.register_modules( unet=unet, scheduler=scheduler, movq=movq, prior_prior=prior_prior, prior_image_encoder=prior_image_encoder, prior_text_encoder=prior_text_encoder, prior_tokenizer=prior_tokenizer, prior_scheduler=prior_scheduler, prior_image_processor=prior_image_processor, ) self.prior_pipe = KandinskyV22PriorPipeline( prior=prior_prior, image_encoder=prior_image_encoder, text_encoder=prior_text_encoder, tokenizer=prior_tokenizer, scheduler=prior_scheduler, image_processor=prior_image_processor, ) 
self.decoder_pipe = KandinskyV22InpaintPipeline( unet=unet, scheduler=scheduler, movq=movq, ) def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) def progress_bar(self, iterable=None, total=None): self.prior_pipe.progress_bar(iterable=iterable, total=total) self.decoder_pipe.progress_bar(iterable=iterable, total=total) self.decoder_pipe.enable_model_cpu_offload() def set_progress_bar_config(self, **kwargs): self.prior_pipe.set_progress_bar_config(**kwargs) self.decoder_pipe.set_progress_bar_config(**kwargs) @torch.no_grad() @replace_example_docstring(INPAINT_EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]], image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]], mask_image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]], negative_prompt: Optional[Union[str, List[str]]] = None, num_inference_steps: int = 100, guidance_scale: float = 4.0, num_images_per_prompt: int = 1, height: int = 512, width: int = 512, prior_guidance_scale: float = 4.0, prior_num_inference_steps: int = 25, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 
latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded again. mask_image (`np.array`): Tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. prior_guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. prior_num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. prior_callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. prior_callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. 
Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ prior_kwargs = {} if kwargs.get("prior_callback", None) is not None: prior_kwargs["callback"] = kwargs.pop("prior_callback") deprecate( "prior_callback", "1.0.0", "Passing `prior_callback` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", ) if kwargs.get("prior_callback_steps", None) is not None: deprecate( "prior_callback_steps", "1.0.0", "Passing `prior_callback_steps` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", ) prior_kwargs["callback_steps"] = kwargs.pop("prior_callback_steps") prior_outputs = self.prior_pipe( prompt=prompt, negative_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, num_inference_steps=prior_num_inference_steps, generator=generator, latents=latents, guidance_scale=prior_guidance_scale, output_type="pt", return_dict=False, callback_on_step_end=prior_callback_on_step_end, callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, **prior_kwargs, ) image_embeds = prior_outputs[0] negative_image_embeds = prior_outputs[1] prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt image = [image] if isinstance(prompt, PIL.Image.Image) else image mask_image = [mask_image] if isinstance(mask_image, PIL.Image.Image) else mask_image if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: prompt = (image_embeds.shape[0] // len(prompt)) * prompt if ( isinstance(image, (list, tuple)) and len(image) < image_embeds.shape[0] and image_embeds.shape[0] % len(image) == 0 ): image = (image_embeds.shape[0] // len(image)) * image if ( isinstance(mask_image, (list, tuple)) and len(mask_image) < image_embeds.shape[0] and image_embeds.shape[0] % len(mask_image) == 0 ): mask_image = (image_embeds.shape[0] // len(mask_image)) * mask_image outputs = self.decoder_pipe( image=image, mask_image=mask_image, image_embeds=image_embeds, 
negative_image_embeds=negative_image_embeds, width=width, height=height, num_inference_steps=num_inference_steps, generator=generator, guidance_scale=guidance_scale, output_type=output_type, return_dict=return_dict, callback_on_step_end=callback_on_step_end, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, **kwargs, ) self.maybe_free_model_hooks() return outputs ================================================ FILE: src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, List, Optional, Union import torch from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDPMScheduler from ...utils import ( logging, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> import numpy as np >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline >>> from transformers import pipeline >>> from diffusers.utils import load_image >>> def make_hint(image, depth_estimator): ... image = depth_estimator(image)["depth"] ... image = np.array(image) ... image = image[:, :, None] ... image = np.concatenate([image, image, image], axis=2) ... 
detected_map = torch.from_numpy(image).float() / 255.0 ... hint = detected_map.permute(2, 0, 1) ... return hint >>> depth_estimator = pipeline("depth-estimation") >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior = pipe_prior.to("cuda") >>> pipe = KandinskyV22ControlnetPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> img = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ).resize((768, 768)) >>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda") >>> prompt = "A robot, 4k photo" >>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature" >>> generator = torch.Generator(device="cuda").manual_seed(43) >>> image_emb, zero_image_emb = pipe_prior( ... prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator ... ).to_tuple() >>> images = pipe( ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... hint=hint, ... num_inference_steps=50, ... generator=generator, ... height=768, ... width=768, ... 
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
    new_height = height // scale_factor**2
    if height % scale_factor**2 != 0:
        new_height += 1
    new_width = width // scale_factor**2
    if width % scale_factor**2 != 0:
        new_width += 1
    return new_height * scale_factor, new_width * scale_factor


class KandinskyV22ControlnetPipeline(DiffusionPipeline):
    """
    Pipeline for text-to-image generation using Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        scheduler ([`DDIMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
    """

    model_cpu_offload_seq = "unet->movq"

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        movq: VQModel,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        image_embeds: Union[torch.Tensor, List[torch.Tensor]],
        negative_image_embeds: Union[torch.Tensor, List[torch.Tensor]],
        hint: torch.Tensor,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
                The clip image embeddings for text prompt, that will be used to condition the image generation.
            negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
                The clip image embeddings for negative text prompt, will be used to condition the image generation.
                Only used when classifier-free guidance is active (`guidance_scale > 1`).
            hint (`torch.Tensor`):
                The controlnet condition.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
        if isinstance(hint, list):
            hint = torch.cat(hint, dim=0)

        batch_size = image_embeds.shape[0] * num_images_per_prompt

        # Expand the conditioning to the effective batch size regardless of guidance. The original only did
        # this (and the dtype/device cast below) inside the guidance branch, so `guidance_scale <= 1` left the
        # conditioning batch smaller than the latents batch and possibly on the wrong device/dtype.
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        hint = hint.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

            # Unconditional half first, conditional half second; the loop below relies on this order.
            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
            hint = torch.cat([hint, hint], dim=0)

        image_embeds = image_embeds.to(dtype=self.unet.dtype, device=device)
        hint = hint.to(dtype=self.unet.dtype, device=device)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps_tensor = self.scheduler.timesteps

        num_channels_latents = self.movq.config.latent_channels

        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)

        # create initial latent
        latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            image_embeds.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint}
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=None,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                # The UNet output carries noise and variance channels; guidance is applied to the noise part only.
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            # Drop the variance channels when the scheduler does not consume a learned variance.
            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            )[0]

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

        # post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        # Offload all models
        self.maybe_free_model_hooks()

        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        if output_type in ["np", "pil"]:
            image = image * 0.5 + 0.5
            image = image.clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, List, Optional, Union import numpy as np import PIL.Image import torch from PIL import Image from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDPMScheduler from ...utils import ( logging, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> import numpy as np >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline >>> from transformers import pipeline >>> from diffusers.utils import load_image >>> def make_hint(image, depth_estimator): ... image = depth_estimator(image)["depth"] ... image = np.array(image) ... image = image[:, :, None] ... image = np.concatenate([image, image, image], axis=2) ... detected_map = torch.from_numpy(image).float() / 255.0 ... hint = detected_map.permute(2, 0, 1) ... return hint >>> depth_estimator = pipeline("depth-estimation") >>> pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior = pipe_prior.to("cuda") >>> pipe = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> img = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... 
).resize((768, 768)) >>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda") >>> prompt = "A robot, 4k photo" >>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature" >>> generator = torch.Generator(device="cuda").manual_seed(43) >>> img_emb = pipe_prior(prompt=prompt, image=img, strength=0.85, generator=generator) >>> negative_emb = pipe_prior(prompt=negative_prior_prompt, image=img, strength=1, generator=generator) >>> images = pipe( ... image=img, ... strength=0.5, ... image_embeds=img_emb.image_embeds, ... negative_image_embeds=negative_emb.image_embeds, ... hint=hint, ... num_inference_steps=50, ... generator=generator, ... height=768, ... width=768, ... 
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
    new_height = height // scale_factor**2
    if height % scale_factor**2 != 0:
        new_height += 1
    new_width = width // scale_factor**2
    if width % scale_factor**2 != 0:
        new_width += 1
    return new_height * scale_factor, new_width * scale_factor


# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
def prepare_image(pil_image, w=512, h=512):
    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
    arr = np.array(pil_image.convert("RGB"))
    arr = arr.astype(np.float32) / 127.5 - 1
    arr = np.transpose(arr, [2, 0, 1])
    image = torch.from_numpy(arr).unsqueeze(0)
    return image


class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
    """
    Pipeline for image-to-image generation using Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        scheduler ([`DDIMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
    """

    model_cpu_offload_seq = "unet->movq"

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        movq: VQModel,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)

    # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2_img2img.KandinskyV22Img2ImgPipeline.prepare_latents
    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            init_latents = image

        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            elif isinstance(generator, list):
                init_latents = [
                    self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = self.movq.encode(image).latent_dist.sample(generator)

            init_latents = self.movq.config.scaling_factor * init_latents

        init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

        latents = init_latents

        return latents

    @torch.no_grad()
    def __call__(
        self,
        image_embeds: Union[torch.Tensor, List[torch.Tensor]],
        image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]],
        negative_image_embeds: Union[torch.Tensor, List[torch.Tensor]],
        hint: torch.Tensor,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        strength: float = 0.3,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
                The clip image embeddings for text prompt, that will be used to condition the image generation.
            image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. PIL images are resized to `(width, height)` and normalized; tensors are used as given.
            strength (`float`, *optional*, defaults to 0.3):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            hint (`torch.Tensor`):
                The controlnet condition.
            negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
                The clip image embeddings for negative text prompt, will be used to condition the image generation.
                Only used when classifier-free guidance is active (`guidance_scale > 1`).
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
        if isinstance(hint, list):
            hint = torch.cat(hint, dim=0)

        # Number of prompts, before expansion by `num_images_per_prompt`.
        batch_size = image_embeds.shape[0]

        # Expand the conditioning to the effective batch size regardless of guidance. The original only did
        # this (and the dtype/device cast below) inside the guidance branch, so `guidance_scale <= 1` left the
        # conditioning batch smaller than the latents batch and possibly on the wrong device/dtype.
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        hint = hint.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

            # Unconditional half first, conditional half second; the loop below relies on this order.
            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
            hint = torch.cat([hint, hint], dim=0)

        image_embeds = image_embeds.to(dtype=self.unet.dtype, device=device)
        hint = hint.to(dtype=self.unet.dtype, device=device)

        if not isinstance(image, list):
            image = [image]
        if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
            )

        # `prepare_image` relies on the PIL API, so tensors (which pass the validation above) used to crash
        # here. They are now concatenated as given.
        # NOTE(review): assumes tensor inputs are already batched and normalized like `prepare_image` output
        # (i.e. `(B, C, H, W)` in [-1, 1]) -- confirm against callers.
        image = torch.cat(
            [i if isinstance(i, torch.Tensor) else prepare_image(i, width, height) for i in image], dim=0
        )
        image = image.to(dtype=image_embeds.dtype, device=device)

        latents = self.movq.encode(image)["latents"]
        latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        # Every sample is noised to the first timestep that will actually be denoised.
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
        latents = self.prepare_latents(
            latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
        )
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint}
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=None,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                # The UNet output carries noise and variance channels; guidance is applied to the noise part only.
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            # Drop the variance channels when the scheduler does not consume a learned variance.
            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            )[0]

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

        # post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        # Offload all models
        self.maybe_free_model_hooks()

        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        if output_type in ["np", "pil"]:
            image = image * 0.5 + 0.5
            image = image.clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
from typing import Callable, Dict, List, Optional, Union import numpy as np import PIL.Image import torch from PIL import Image from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDPMScheduler from ...utils import deprecate, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22Img2ImgPipeline, KandinskyV22PriorPipeline >>> from diffusers.utils import load_image >>> import torch >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> prompt = "A red cartoon frog, 4k" >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) >>> pipe = KandinskyV22Img2ImgPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> init_image = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/frog.png" ... ) >>> image = pipe( ... image=init_image, ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=100, ... strength=0.2, ... 
).images >>> image[0].save("red_frog.png") ``` """ # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width def downscale_height_and_width(height, width, scale_factor=8): new_height = height // scale_factor**2 if height % scale_factor**2 != 0: new_height += 1 new_width = width // scale_factor**2 if width % scale_factor**2 != 0: new_width += 1 return new_height * scale_factor, new_width * scale_factor # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image def prepare_image(pil_image, w=512, h=512): pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) arr = np.array(pil_image.convert("RGB")) arr = arr.astype(np.float32) / 127.5 - 1 arr = np.transpose(arr, [2, 0, 1]) image = torch.from_numpy(arr).unsqueeze(0) return image class KandinskyV22Img2ImgPipeline(DiffusionPipeline): """ Pipeline for image-to-image generation using Kandinsky This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: scheduler ([`DDIMScheduler`]): A scheduler to be used in combination with `unet` to generate image latents. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the image embedding. movq ([`VQModel`]): MoVQ Decoder to generate the image from the latents. 
""" model_cpu_offload_seq = "unet->movq" _callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds"] def __init__( self, unet: UNet2DConditionModel, scheduler: DDPMScheduler, movq: VQModel, ): super().__init__() self.register_modules( unet=unet, scheduler=scheduler, movq=movq, ) self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start:] return timesteps, num_inference_steps - t_start def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt if image.shape[1] == 4: init_latents = image else: if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) elif isinstance(generator, list): init_latents = [ self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = self.movq.encode(image).latent_dist.sample(generator) init_latents = self.movq.config.scaling_factor * init_latents init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents return latents @property def guidance_scale(self): return self._guidance_scale @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @property def num_timesteps(self): return self._num_timesteps @torch.no_grad() def __call__( self, image_embeds: Union[torch.Tensor, List[torch.Tensor]], image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]], negative_image_embeds: Union[torch.Tensor, List[torch.Tensor]], height: int = 512, width: int = 512, num_inference_steps: int = 100, guidance_scale: float = 4.0, strength: float = 0.3, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): """ Function invoked when calling the pipeline for generation. Args: image_embeds (`torch.Tensor` or `List[torch.Tensor]`): The clip image embeddings for text prompt, that will be used to condition the image generation. image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. 
Can also accept image latents as `image`, if passing latents directly, it will not be encoded again. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. 
output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. 
Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) device = self._execution_device self._guidance_scale = guidance_scale if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) batch_size = image_embeds.shape[0] if isinstance(negative_image_embeds, list): negative_image_embeds = torch.cat(negative_image_embeds, dim=0) if self.do_classifier_free_guidance: image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( dtype=self.unet.dtype, device=device ) if not isinstance(image, list): image = [image] if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): raise ValueError( f"Input is in incorrect format: {[type(i) for i in image]}. 
Currently, we only support PIL image and pytorch tensor" ) image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) image = image.to(dtype=image_embeds.dtype, device=device) latents = self.movq.encode(image)["latents"] latents = latents.repeat_interleave(num_images_per_prompt, dim=0) self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) height, width = downscale_height_and_width(height, width, self.movq_scale_factor) latents = self.prepare_latents( latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator ) self._num_timesteps = len(timesteps) for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents added_cond_kwargs = {"image_embeds": image_embeds} noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=None, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if self.do_classifier_free_guidance: noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) _, variance_pred_text = variance_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) if not ( hasattr(self.scheduler.config, "variance_type") and self.scheduler.config.variance_type in ["learned", "learned_range"] ): noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, )[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: 
callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) image_embeds = callback_outputs.pop("image_embeds", image_embeds) negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds) if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError( f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}" ) if not output_type == "latent": # post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] if output_type in ["np", "pil"]: image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) else: image = latents # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from copy import deepcopy from typing import Callable, Dict, List, Optional, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from packaging import version from PIL import Image from ... import __version__ from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDPMScheduler from ...utils import deprecate, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22InpaintPipeline, KandinskyV22PriorPipeline >>> from diffusers.utils import load_image >>> import torch >>> import numpy as np >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> prompt = "a hat" >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) >>> pipe = KandinskyV22InpaintPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> init_image = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> mask = np.zeros((768, 768), dtype=np.float32) >>> mask[:250, 250:-250] = 1 >>> out = pipe( ... image=init_image, ... mask_image=mask, ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=50, ... 
) >>> image = out.images[0] >>> image.save("cat_with_hat.png") ``` """ # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width def downscale_height_and_width(height, width, scale_factor=8): new_height = height // scale_factor**2 if height % scale_factor**2 != 0: new_height += 1 new_width = width // scale_factor**2 if width % scale_factor**2 != 0: new_width += 1 return new_height * scale_factor, new_width * scale_factor # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask def prepare_mask(masks): prepared_masks = [] for mask in masks: old_mask = deepcopy(mask) for i in range(mask.shape[1]): for j in range(mask.shape[2]): if old_mask[0][i][j] == 1: continue if i != 0: mask[:, i - 1, j] = 0 if j != 0: mask[:, i, j - 1] = 0 if i != 0 and j != 0: mask[:, i - 1, j - 1] = 0 if i != mask.shape[1] - 1: mask[:, i + 1, j] = 0 if j != mask.shape[2] - 1: mask[:, i, j + 1] = 0 if i != mask.shape[1] - 1 and j != mask.shape[2] - 1: mask[:, i + 1, j + 1] = 0 prepared_masks.append(mask) return torch.stack(prepared_masks, dim=0) # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask_and_masked_image def prepare_mask_and_masked_image(image, mask, height, width): r""" Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the ``image`` and ``1`` for the ``mask``. The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be binarized (``mask > 0.5``) and cast to ``torch.float32`` too. Args: image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. 
mask (_type_): The mask to apply to the image, i.e. regions to inpaint. It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. Raises: ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (ot the other way around). Returns: tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4 dimensions: ``batch x channels x height x width``. """ if image is None: raise ValueError("`image` input cannot be undefined.") if mask is None: raise ValueError("`mask_image` input cannot be undefined.") if isinstance(image, torch.Tensor): if not isinstance(mask, torch.Tensor): raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") # Batch single image if image.ndim == 3: assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" image = image.unsqueeze(0) # Batch and add channel dim for single mask if mask.ndim == 2: mask = mask.unsqueeze(0).unsqueeze(0) # Batch single mask or add channel dim if mask.ndim == 3: # Single batched mask, no channel dim or single mask not batched but channel dim if mask.shape[0] == 1: mask = mask.unsqueeze(0) # Batched masks no channel dim else: mask = mask.unsqueeze(1) assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" # Check image is in [-1, 1] if image.min() < -1 or 
image.max() > 1: raise ValueError("Image should be in [-1, 1] range") # Check mask is in [0, 1] if mask.min() < 0 or mask.max() > 1: raise ValueError("Mask should be in [0, 1] range") # Binarize mask mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 # Image as float32 image = image.to(dtype=torch.float32) elif isinstance(mask, torch.Tensor): raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") else: # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): # resize all images w.r.t passed height an width image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image] image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): image = np.concatenate([i[None, :] for i in image], axis=0) image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 # preprocess mask if isinstance(mask, (PIL.Image.Image, np.ndarray)): mask = [mask] if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) mask = mask.astype(np.float32) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): mask = np.concatenate([m[None, None, :] for m in mask], axis=0) mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 mask = torch.from_numpy(mask) mask = 1 - mask return mask, image class KandinskyV22InpaintPipeline(DiffusionPipeline): """ Pipeline for text-guided image inpainting using Kandinsky2.1 This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
Args: scheduler ([`DDIMScheduler`]): A scheduler to be used in combination with `unet` to generate image latents. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the image embedding. movq ([`VQModel`]): MoVQ Decoder to generate the image from the latents. """ model_cpu_offload_seq = "unet->movq" _callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds", "masked_image", "mask_image"] def __init__( self, unet: UNet2DConditionModel, scheduler: DDPMScheduler, movq: VQModel, ): super().__init__() self.register_modules( unet=unet, scheduler=scheduler, movq=movq, ) self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) self._warn_has_been_called = False # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma return latents @property def guidance_scale(self): return self._guidance_scale @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @property def num_timesteps(self): return self._num_timesteps @torch.no_grad() def __call__( self, image_embeds: Union[torch.Tensor, List[torch.Tensor]], image: Union[torch.Tensor, PIL.Image.Image], mask_image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], negative_image_embeds: Union[torch.Tensor, List[torch.Tensor]], height: int = 512, width: int = 512, num_inference_steps: int = 100, guidance_scale: float = 4.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback_on_step_end: 
Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): """ Function invoked when calling the pipeline for generation. Args: image_embeds (`torch.Tensor` or `List[torch.Tensor]`): The clip image embeddings for text prompt, that will be used to condition the image generation. image (`PIL.Image.Image`): `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will be masked out with `mask_image` and repainted according to `prompt`. mask_image (`np.array`): Tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. 
num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ if not self._warn_has_been_called and version.parse(version.parse(__version__).base_version) < version.parse( "0.23.0.dev0" ): logger.warning( "Please note that the expected format of `mask_image` has recently been changed. 
" "Before diffusers == 0.19.0, Kandinsky Inpainting pipelines repainted black pixels and preserved black pixels. " "As of diffusers==0.19.0 this behavior has been inverted. Now white pixels are repainted and black pixels are preserved. " "This way, Kandinsky's masking behavior is aligned with Stable Diffusion. " "THIS means that you HAVE to invert the input mask to have the same behavior as before as explained in https://github.com/huggingface/diffusers/pull/4207. " "This warning will be surpressed after the first inference call and will be removed in diffusers>0.23.0" ) self._warn_has_been_called = True callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) self._guidance_scale = guidance_scale device = self._execution_device if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) batch_size = image_embeds.shape[0] * num_images_per_prompt if isinstance(negative_image_embeds, list): negative_image_embeds = torch.cat(negative_image_embeds, dim=0) if self.do_classifier_free_guidance: image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = torch.cat([negative_image_embeds, image_embeds], 
dim=0).to( dtype=self.unet.dtype, device=device ) self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # preprocess image and mask mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) image = image.to(dtype=image_embeds.dtype, device=device) image = self.movq.encode(image)["latents"] mask_image = mask_image.to(dtype=image_embeds.dtype, device=device) image_shape = tuple(image.shape[-2:]) mask_image = F.interpolate( mask_image, image_shape, mode="nearest", ) mask_image = prepare_mask(mask_image) masked_image = image * mask_image mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0) if self.do_classifier_free_guidance: mask_image = mask_image.repeat(2, 1, 1, 1) masked_image = masked_image.repeat(2, 1, 1, 1) num_channels_latents = self.movq.config.latent_channels height, width = downscale_height_and_width(height, width, self.movq_scale_factor) # create initial latent latents = self.prepare_latents( (batch_size, num_channels_latents, height, width), image_embeds.dtype, device, generator, latents, self.scheduler, ) noise = torch.clone(latents) self._num_timesteps = len(timesteps) for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1) added_cond_kwargs = {"image_embeds": image_embeds} noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=None, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if self.do_classifier_free_guidance: noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) _, variance_pred_text = variance_pred.chunk(2) noise_pred = 
noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) if not ( hasattr(self.scheduler.config, "variance_type") and self.scheduler.config.variance_type in ["learned", "learned_range"] ): noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, )[0] init_latents_proper = image[:1] init_mask = mask_image[:1] if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] init_latents_proper = self.scheduler.add_noise( init_latents_proper, noise, torch.tensor([noise_timestep]) ) latents = init_mask * init_latents_proper + (1 - init_mask) * latents if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) image_embeds = callback_outputs.pop("image_embeds", image_embeds) negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds) masked_image = callback_outputs.pop("masked_image", masked_image) mask_image = callback_outputs.pop("mask_image", mask_image) if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # post-processing latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError( f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}" ) if not output_type == "latent": image = self.movq.decode(latents, force_not_quantize=True)["sample"] if output_type in ["np", "pil"]: image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = 
self.numpy_to_pil(image) else: image = latents # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py ================================================ from typing import Callable, Dict, List, Optional, Union import PIL.Image import torch from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer from ...schedulers import UnCLIPScheduler from ...utils import ( logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..kandinsky import KandinskyPriorPipelineOutput from ..pipeline_utils import DiffusionPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline >>> import torch >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior") >>> pipe_prior.to("cuda") >>> prompt = "red cat, 4k photo" >>> image_emb, negative_image_emb = pipe_prior(prompt).to_tuple() >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=image_emb, ... negative_image_embeds=negative_image_emb, ... height=768, ... width=768, ... num_inference_steps=50, ... ).images >>> image[0].save("cat.png") ``` """ EXAMPLE_INTERPOLATE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline >>> from diffusers.utils import load_image >>> import PIL >>> import torch >>> from torchvision import transforms >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... 
) >>> pipe_prior.to("cuda") >>> img1 = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> img2 = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/starry_night.jpeg" ... ) >>> images_texts = ["a cat", img1, img2] >>> weights = [0.3, 0.3, 0.4] >>> out = pipe_prior.interpolate(images_texts, weights) >>> pipe = KandinskyV22Pipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=out.image_embeds, ... negative_image_embeds=out.negative_image_embeds, ... height=768, ... width=768, ... num_inference_steps=50, ... ).images[0] >>> image.save("starry_cat.png") ``` """ class KandinskyV22PriorPipeline(DiffusionPipeline): """ Pipeline for generating image prior for Kandinsky This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: prior ([`PriorTransformer`]): The canonical unCLIP prior to approximate the image embedding from the text embedding. image_encoder ([`CLIPVisionModelWithProjection`]): Frozen image-encoder. text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler ([`UnCLIPScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. image_processor ([`CLIPImageProcessor`]): A image_processor to be used to preprocess image from clip. 
""" model_cpu_offload_seq = "text_encoder->image_encoder->prior" _exclude_from_cpu_offload = ["prior"] _callback_tensor_inputs = ["latents", "prompt_embeds", "text_encoder_hidden_states", "text_mask"] def __init__( self, prior: PriorTransformer, image_encoder: CLIPVisionModelWithProjection, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, scheduler: UnCLIPScheduler, image_processor: CLIPImageProcessor, ): super().__init__() self.register_modules( prior=prior, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, image_encoder=image_encoder, image_processor=image_processor, ) @torch.no_grad() @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) def interpolate( self, images_and_prompts: List[Union[str, PIL.Image.Image, torch.Tensor]], weights: List[float], num_images_per_prompt: int = 1, num_inference_steps: int = 25, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, negative_prior_prompt: Optional[str] = None, negative_prompt: str = "", guidance_scale: float = 4.0, device=None, ): """ Function invoked when using the prior pipeline for interpolation. Args: images_and_prompts (`List[Union[str, PIL.Image.Image, torch.Tensor]]`): list of prompts and images to guide the image generation. weights: (`List[float]`): list of weights for each condition in `images_and_prompts` num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. 
latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. negative_prior_prompt (`str`, *optional*): The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt (`str` or `List[str]`, *optional*): The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. 
Examples: Returns: [`KandinskyPriorPipelineOutput`] or `tuple` """ device = device or self.device if len(images_and_prompts) != len(weights): raise ValueError( f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" ) image_embeddings = [] for cond, weight in zip(images_and_prompts, weights): if isinstance(cond, str): image_emb = self( cond, num_inference_steps=num_inference_steps, num_images_per_prompt=num_images_per_prompt, generator=generator, latents=latents, negative_prompt=negative_prior_prompt, guidance_scale=guidance_scale, ).image_embeds.unsqueeze(0) elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): if isinstance(cond, PIL.Image.Image): cond = ( self.image_processor(cond, return_tensors="pt") .pixel_values[0] .unsqueeze(0) .to(dtype=self.image_encoder.dtype, device=device) ) image_emb = self.image_encoder(cond)["image_embeds"].repeat(num_images_per_prompt, 1).unsqueeze(0) else: raise ValueError( f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" ) image_embeddings.append(image_emb * weight) image_emb = torch.cat(image_embeddings).sum(dim=0) out_zero = self( negative_prompt, num_inference_steps=num_inference_steps, num_images_per_prompt=num_images_per_prompt, generator=generator, latents=latents, negative_prompt=negative_prior_prompt, guidance_scale=guidance_scale, ) zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb) # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected 
latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma return latents # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed def get_zero_embed(self, batch_size=1, device=None): device = device or self.device zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( device=device, dtype=self.image_encoder.dtype ) zero_image_emb = self.image_encoder(zero_img)["image_embeds"] zero_image_emb = zero_image_emb.repeat(batch_size, 1) return zero_image_emb # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, ): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids text_mask = text_inputs.attention_mask.bool().to(device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder(text_input_ids.to(device)) prompt_embeds = text_encoder_output.text_embeds text_encoder_hidden_states = text_encoder_output.last_hidden_state prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) 
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = 
uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) return prompt_embeds, text_encoder_hidden_states, text_mask @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @property def guidance_scale(self): return self._guidance_scale @property def num_timesteps(self): return self._num_timesteps @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, num_inference_steps: int = 25, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, guidance_scale: float = 4.0, output_type: Optional[str] = "pt", # pt only return_dict: bool = True, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. output_type (`str`, *optional*, defaults to `"pt"`): The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. 
Examples: Returns: [`KandinskyPriorPipelineOutput`] or `tuple` """ if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if isinstance(prompt, str): prompt = [prompt] elif not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] elif not isinstance(negative_prompt, list) and negative_prompt is not None: raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") # if the negative prompt is defined we double the batch size to # directly retrieve the negative prompt embedding if negative_prompt is not None: prompt = prompt + negative_prompt negative_prompt = 2 * negative_prompt device = self._execution_device batch_size = len(prompt) batch_size = batch_size * num_images_per_prompt self._guidance_scale = guidance_scale prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( prompt, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt ) # prior self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps embedding_dim = self.prior.config.embedding_dim latents = self.prepare_latents( (batch_size, embedding_dim), prompt_embeds.dtype, device, generator, latents, self.scheduler, ) self._num_timesteps = len(timesteps) for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents predicted_image_embedding = self.prior( latent_model_input, timestep=t, proj_embedding=prompt_embeds, 
encoder_hidden_states=text_encoder_hidden_states, attention_mask=text_mask, ).predicted_image_embedding if self.do_classifier_free_guidance: predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) predicted_image_embedding = predicted_image_embedding_uncond + self.guidance_scale * ( predicted_image_embedding_text - predicted_image_embedding_uncond ) if i + 1 == timesteps.shape[0]: prev_timestep = None else: prev_timestep = timesteps[i + 1] latents = self.scheduler.step( predicted_image_embedding, timestep=t, sample=latents, generator=generator, prev_timestep=prev_timestep, ).prev_sample if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) text_encoder_hidden_states = callback_outputs.pop( "text_encoder_hidden_states", text_encoder_hidden_states ) text_mask = callback_outputs.pop("text_mask", text_mask) latents = self.prior.post_process_latents(latents) image_embeddings = latents # if negative prompt has been defined, we retrieve split the image embedding into two if negative_prompt is None: zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) else: image_embeddings, zero_embeds = image_embeddings.chunk(2) self.maybe_free_model_hooks() if output_type not in ["pt", "np"]: raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") if output_type == "np": image_embeddings = image_embeddings.cpu().numpy() zero_embeds = zero_embeds.cpu().numpy() if not return_dict: return (image_embeddings, zero_embeds) return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) ================================================ FILE: 
src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py ================================================ from typing import List, Optional, Union import PIL.Image import torch from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer from ...schedulers import UnCLIPScheduler from ...utils import ( logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..kandinsky import KandinskyPriorPipelineOutput from ..pipeline_utils import DiffusionPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline >>> import torch >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> prompt = "red cat, 4k photo" >>> img = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> image_emb, nagative_image_emb = pipe_prior(prompt, image=img, strength=0.2).to_tuple() >>> pipe = KandinskyPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-decoder, torch_dtype=torch.float16" ... ) >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=image_emb, ... negative_image_embeds=negative_image_emb, ... height=768, ... width=768, ... num_inference_steps=100, ... ).images >>> image[0].save("cat.png") ``` """ EXAMPLE_INTERPOLATE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22Pipeline >>> from diffusers.utils import load_image >>> import PIL >>> import torch >>> from torchvision import transforms >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... 
) >>> pipe_prior.to("cuda") >>> img1 = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> img2 = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/starry_night.jpeg" ... ) >>> images_texts = ["a cat", img1, img2] >>> weights = [0.3, 0.3, 0.4] >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights) >>> pipe = KandinskyV22Pipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=150, ... ).images[0] >>> image.save("starry_cat.png") ``` """ class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline): """ Pipeline for generating image prior for Kandinsky This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: prior ([`PriorTransformer`]): The canonical unCLIP prior to approximate the image embedding from the text embedding. image_encoder ([`CLIPVisionModelWithProjection`]): Frozen image-encoder. text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler ([`UnCLIPScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. 
""" model_cpu_offload_seq = "text_encoder->image_encoder->prior" _exclude_from_cpu_offload = ["prior"] def __init__( self, prior: PriorTransformer, image_encoder: CLIPVisionModelWithProjection, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, scheduler: UnCLIPScheduler, image_processor: CLIPImageProcessor, ): super().__init__() self.register_modules( prior=prior, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, image_encoder=image_encoder, image_processor=image_processor, ) def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start:] return timesteps, num_inference_steps - t_start @torch.no_grad() @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) def interpolate( self, images_and_prompts: List[Union[str, PIL.Image.Image, torch.Tensor]], weights: List[float], num_images_per_prompt: int = 1, num_inference_steps: int = 25, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, negative_prior_prompt: Optional[str] = None, negative_prompt: str = "", guidance_scale: float = 4.0, device=None, ): """ Function invoked when using the prior pipeline for interpolation. Args: images_and_prompts (`List[Union[str, PIL.Image.Image, torch.Tensor]]`): list of prompts and images to guide the image generation. weights: (`List[float]`): list of weights for each condition in `images_and_prompts` num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. negative_prior_prompt (`str`, *optional*): The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt (`str` or `List[str]`, *optional*): The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. 
Examples: Returns: [`KandinskyPriorPipelineOutput`] or `tuple` """ device = device or self.device if len(images_and_prompts) != len(weights): raise ValueError( f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" ) image_embeddings = [] for cond, weight in zip(images_and_prompts, weights): if isinstance(cond, str): image_emb = self( cond, num_inference_steps=num_inference_steps, num_images_per_prompt=num_images_per_prompt, generator=generator, latents=latents, negative_prompt=negative_prior_prompt, guidance_scale=guidance_scale, ).image_embeds.unsqueeze(0) elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): image_emb = self._encode_image( cond, device=device, num_images_per_prompt=num_images_per_prompt ).unsqueeze(0) else: raise ValueError( f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" ) image_embeddings.append(image_emb * weight) image_emb = torch.cat(image_embeddings).sum(dim=0) return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=torch.randn_like(image_emb)) def _encode_image( self, image: Union[torch.Tensor, List[PIL.Image.Image]], device, num_images_per_prompt, ): if not isinstance(image, torch.Tensor): image = self.image_processor(image, return_tensors="pt").pixel_values.to( dtype=self.image_encoder.dtype, device=device ) image_emb = self.image_encoder(image)["image_embeds"] # B, D image_emb = image_emb.repeat_interleave(num_images_per_prompt, dim=0) image_emb.to(device=device) return image_emb def prepare_latents(self, emb, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): emb = emb.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt init_latents = emb if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: additional_image_per_prompt = batch_size // init_latents.shape[0] 
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) else: init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents return latents # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed def get_zero_embed(self, batch_size=1, device=None): device = device or self.device zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( device=device, dtype=self.image_encoder.dtype ) zero_image_emb = self.image_encoder(zero_img)["image_embeds"] zero_image_emb = zero_image_emb.repeat(batch_size, 1) return zero_image_emb # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, ): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids text_mask = text_inputs.attention_mask.bool().to(device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only 
handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder(text_input_ids.to(device)) prompt_embeds = text_encoder_output.text_embeds text_encoder_hidden_states = text_encoder_output.last_hidden_state prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) return prompt_embeds, text_encoder_hidden_states, text_mask @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]], image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], strength: float = 0.3, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, num_inference_steps: int = 25, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, guidance_scale: float = 4.0, output_type: Optional[str] = "pt", # pt only return_dict: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `emb`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. emb (`torch.Tensor`): The image embedding. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. output_type (`str`, *optional*, defaults to `"pt"`): The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Examples: Returns: [`KandinskyPriorPipelineOutput`] or `tuple` """ if isinstance(prompt, str): prompt = [prompt] elif not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] elif not isinstance(negative_prompt, list) and negative_prompt is not None: raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") # if the negative prompt is defined we double the batch size to # directly retrieve the negative prompt embedding if negative_prompt is not None: prompt = prompt + negative_prompt negative_prompt = 2 * negative_prompt device = self._execution_device batch_size = len(prompt) batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) if 
not isinstance(image, List): image = [image] if isinstance(image[0], torch.Tensor): image = torch.cat(image, dim=0) if isinstance(image, torch.Tensor) and image.ndim == 2: # allow user to pass image_embeds directly image_embeds = image.repeat_interleave(num_images_per_prompt, dim=0) elif isinstance(image, torch.Tensor) and image.ndim != 4: raise ValueError( f" if pass `image` as pytorch tensor, or a list of pytorch tensor, please make sure each tensor has shape [batch_size, channels, height, width], currently {image[0].unsqueeze(0).shape}" ) else: image_embeds = self._encode_image(image, device, num_images_per_prompt) # prior self.scheduler.set_timesteps(num_inference_steps, device=device) latents = image_embeds timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size) latents = self.prepare_latents( latents, latent_timestep, batch_size // num_images_per_prompt, num_images_per_prompt, prompt_embeds.dtype, device, generator, ) for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents predicted_image_embedding = self.prior( latent_model_input, timestep=t, proj_embedding=prompt_embeds, encoder_hidden_states=text_encoder_hidden_states, attention_mask=text_mask, ).predicted_image_embedding if do_classifier_free_guidance: predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( predicted_image_embedding_text - predicted_image_embedding_uncond ) if i + 1 == timesteps.shape[0]: prev_timestep = None else: prev_timestep = timesteps[i + 1] latents = self.scheduler.step( predicted_image_embedding, timestep=t, sample=latents, generator=generator, prev_timestep=prev_timestep, ).prev_sample latents = 
self.prior.post_process_latents(latents) image_embeddings = latents # if negative prompt has been defined, we retrieve split the image embedding into two if negative_prompt is None: zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) else: image_embeddings, zero_embeds = image_embeddings.chunk(2) self.maybe_free_model_hooks() if output_type not in ["pt", "np"]: raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") if output_type == "np": image_embeddings = image_embeddings.cpu().numpy() zero_embeds = zero_embeds.cpu().numpy() if not return_dict: return (image_embeddings, zero_embeds) return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) ================================================ FILE: src/diffusers/pipelines/kandinsky3/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_kandinsky3"] = ["Kandinsky3Pipeline"] _import_structure["pipeline_kandinsky3_img2img"] = ["Kandinsky3Img2ImgPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_kandinsky3 import Kandinsky3Pipeline from .pipeline_kandinsky3_img2img import Kandinsky3Img2ImgPipeline else: import sys 
sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py ================================================ #!/usr/bin/env python3 import argparse import fnmatch from safetensors.torch import load_file from diffusers import Kandinsky3UNet MAPPING = { "to_time_embed.1": "time_embedding.linear_1", "to_time_embed.3": "time_embedding.linear_2", "in_layer": "conv_in", "out_layer.0": "conv_norm_out", "out_layer.2": "conv_out", "down_samples": "down_blocks", "up_samples": "up_blocks", "projection_lin": "encoder_hid_proj.projection_linear", "projection_ln": "encoder_hid_proj.projection_norm", "feature_pooling": "add_time_condition", "to_query": "to_q", "to_key": "to_k", "to_value": "to_v", "output_layer": "to_out.0", "self_attention_block": "attentions.0", } DYNAMIC_MAP = { "resnet_attn_blocks.*.0": "resnets_in.*", "resnet_attn_blocks.*.1": ("attentions.*", 1), "resnet_attn_blocks.*.2": "resnets_out.*", } # MAPPING = {} def convert_state_dict(unet_state_dict): """ Args: Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model. unet_model (torch.nn.Module): The original U-Net model. unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with. Returns: OrderedDict: The converted state dictionary. 
""" # Example of renaming logic (this will vary based on your model's architecture) converted_state_dict = {} for key in unet_state_dict: new_key = key for pattern, new_pattern in MAPPING.items(): new_key = new_key.replace(pattern, new_pattern) for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items(): has_matched = False if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched: star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1]) if isinstance(dyn_new_pattern, tuple): new_star = star + dyn_new_pattern[-1] dyn_new_pattern = dyn_new_pattern[0] else: new_star = star pattern = dyn_pattern.replace("*", str(star)) new_pattern = dyn_new_pattern.replace("*", str(new_star)) new_key = new_key.replace(pattern, new_pattern) has_matched = True converted_state_dict[new_key] = unet_state_dict[key] return converted_state_dict def main(model_path, output_path): # Load your original U-Net model unet_state_dict = load_file(model_path) # Initialize your Kandinsky3UNet model config = {} # Convert the state dict converted_state_dict = convert_state_dict(unet_state_dict) unet = Kandinsky3UNet(config) unet.load_state_dict(converted_state_dict) unet.save_pretrained(output_path) print(f"Converted model saved to {output_path}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format") parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model") parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model") args = parser.parse_args() main(args.model_path, args.output_path) ================================================ FILE: src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py ================================================ from typing import Callable, Dict, List, Optional, Union import torch from transformers import T5EncoderModel, T5Tokenizer from ...loaders import StableDiffusionLoraLoaderMixin 
from ...models import Kandinsky3UNet, VQModel from ...schedulers import DDPMScheduler from ...utils import ( deprecate, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import AutoPipelineForText2Image >>> import torch >>> pipe = AutoPipelineForText2Image.from_pretrained( ... "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16 ... ) >>> pipe.enable_model_cpu_offload() >>> prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background." >>> generator = torch.Generator(device="cpu").manual_seed(0) >>> image = pipe(prompt, num_inference_steps=25, generator=generator).images[0] ``` """ def downscale_height_and_width(height, width, scale_factor=8): new_height = height // scale_factor**2 if height % scale_factor**2 != 0: new_height += 1 new_width = width // scale_factor**2 if width % scale_factor**2 != 0: new_width += 1 return new_height * scale_factor, new_width * scale_factor class Kandinsky3Pipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): model_cpu_offload_seq = "text_encoder->unet->movq" _callback_tensor_inputs = [ "latents", "prompt_embeds", "negative_prompt_embeds", "negative_attention_mask", "attention_mask", ] def __init__( self, tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, unet: Kandinsky3UNet, scheduler: DDPMScheduler, movq: VQModel, ): super().__init__() self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq ) def process_embeds(self, embeddings, attention_mask, cut_context): if cut_context: embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0]) max_seq_length = attention_mask.sum(-1).max() 
+ 1 embeddings = embeddings[:, :max_seq_length] attention_mask = attention_mask[:, :max_seq_length] return embeddings, attention_mask @torch.no_grad() def encode_prompt( self, prompt, do_classifier_free_guidance=True, num_images_per_prompt=1, device=None, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, _cut_context=False, attention_mask: Optional[torch.Tensor] = None, negative_attention_mask: Optional[torch.Tensor] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`, *optional*): torch device to place the resulting embeddings on num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask. Must provide if passing `prompt_embeds` directly. negative_attention_mask (`torch.Tensor`, *optional*): Pre-generated negative attention mask. 
Must provide if passing `negative_prompt_embeds` directly. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] max_length = 128 if prompt_embeds is None: text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(device) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids, attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds, attention_mask = self.process_embeds(prompt_embeds, attention_mask, _cut_context) prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2) if self.text_encoder is not None: dtype = self.text_encoder.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) attention_mask = attention_mask.repeat(num_images_per_prompt, 1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but 
`prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt if negative_prompt is not None: uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=128, truncation=True, return_attention_mask=True, return_tensors="pt", ) text_input_ids = uncond_input.input_ids.to(device) negative_attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( text_input_ids, attention_mask=negative_attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]] negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]] negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2) else: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_attention_mask = torch.zeros_like(attention_mask) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) if negative_prompt_embeds.shape != prompt_embeds.shape: negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) negative_attention_mask = negative_attention_mask.repeat(num_images_per_prompt, 1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes else: negative_prompt_embeds = None negative_attention_mask = None return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma return latents def check_inputs( self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, attention_mask=None, negative_attention_mask=None, ): if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if negative_prompt_embeds is not None and negative_attention_mask is None: raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`") if negative_prompt_embeds is not None and negative_attention_mask is not None: if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape: raise ValueError( "`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but" f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`" f" {negative_attention_mask.shape}." ) if prompt_embeds is not None and attention_mask is None: raise ValueError("Please provide `attention_mask` along with `prompt_embeds`") if prompt_embeds is not None and attention_mask is not None: if prompt_embeds.shape[:2] != attention_mask.shape: raise ValueError( "`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`" f" {attention_mask.shape}." 
) @property def guidance_scale(self): return self._guidance_scale @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @property def num_timesteps(self): return self._num_timesteps @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, num_inference_steps: int = 25, guidance_scale: float = 3.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, height: Optional[int] = 1024, width: Optional[int] = 1024, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, negative_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, latents=None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. num_inference_steps (`int`, *optional*, defaults to 25): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 3.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. height (`int`, *optional*, defaults to self.unet.config.sample_size): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size): The width in pixels of the generated image. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask. Must provide if passing `prompt_embeds` directly. negative_attention_mask (`torch.Tensor`, *optional*): Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. 
Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. clean_caption (`bool`, *optional*, defaults to `True`): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) cut_context = True device = self._execution_device # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, attention_mask, negative_attention_mask, ) self._guidance_scale = guidance_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # 3. Encode input prompt prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt( prompt, self.do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, device=device, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, _cut_context=cut_context, attention_mask=attention_mask, negative_attention_mask=negative_attention_mask, ) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool() # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latents height, width = downscale_height_and_width(height, width, 8) latents = self.prepare_latents( (batch_size * num_images_per_prompt, 4, height, width), prompt_embeds.dtype, device, generator, latents, self.scheduler, ) if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: self.text_encoder_offload_hook.offload() # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, encoder_attention_mask=attention_mask, return_dict=False, )[0] if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, ).prev_sample if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) attention_mask = callback_outputs.pop("attention_mask", attention_mask) negative_attention_mask = callback_outputs.pop("negative_attention_mask", 
negative_attention_mask) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # post-processing if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError( f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}" ) if not output_type == "latent": image = self.movq.decode(latents, force_not_quantize=True)["sample"] if output_type in ["np", "pil"]: image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) else: image = latents self.maybe_free_model_hooks() if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py ================================================ import inspect from typing import Callable, Dict, List, Optional, Union import numpy as np import PIL import PIL.Image import torch from transformers import T5EncoderModel, T5Tokenizer from ...loaders import StableDiffusionLoraLoaderMixin from ...models import Kandinsky3UNet, VQModel from ...schedulers import DDPMScheduler from ...utils import ( deprecate, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import AutoPipelineForImage2Image >>> from diffusers.utils import load_image >>> import torch >>> pipe = AutoPipelineForImage2Image.from_pretrained( ... "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16 ... 
def downscale_height_and_width(height, width, scale_factor=8):
    """
    Convert a pixel-space size into the matching latent-space size.

    Each dimension is divided by `scale_factor**2`, rounded up to the next whole block, and then multiplied by
    `scale_factor` again.

    Args:
        height (`int`): Requested height.
        width (`int`): Requested width.
        scale_factor (`int`, *optional*, defaults to 8): Scaling factor; the block size is its square.

    Returns:
        `Tuple[int, int]`: The scaled `(height, width)` pair.
    """
    block = scale_factor**2

    def _ceil_blocks(extent):
        # Number of `block`-sized chunks needed to cover `extent` (ceiling division).
        whole, leftover = divmod(extent, block)
        return whole + 1 if leftover != 0 else whole

    return _ceil_blocks(height) * scale_factor, _ceil_blocks(width) * scale_factor
@torch.no_grad()
def encode_prompt(
    self,
    prompt,
    do_classifier_free_guidance=True,
    num_images_per_prompt=1,
    device=None,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    _cut_context=False,
    attention_mask: Optional[torch.Tensor] = None,
    negative_attention_mask: Optional[torch.Tensor] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            whether to use classifier free guidance or not
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            number of images that should be generated per prompt
        device: (`torch.device`, *optional*):
            torch device to place the resulting embeddings on
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        attention_mask (`torch.Tensor`, *optional*):
            Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
        negative_attention_mask (`torch.Tensor`, *optional*):
            Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask)`. Every returned
        tensor has `batch * num_images_per_prompt` rows, with the copies of one prompt stored consecutively. The
        two negative entries are `None` when `do_classifier_free_guidance` is `False`.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    if prompt is not None and negative_prompt is not None:
        if type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )

    if device is None:
        device = self._execution_device

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    max_length = 128

    if prompt_embeds is None:
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)
        attention_mask = text_inputs.attention_mask.to(device)
        prompt_embeds = self.text_encoder(
            text_input_ids,
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]
        prompt_embeds, attention_mask = self._process_embeds(prompt_embeds, attention_mask, _cut_context)
        # zero out the hidden states of padding tokens
        prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2)

    if self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    else:
        dtype = None

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    # The mask must be duplicated with the same repeat+view pattern as the embeddings so that row `i` of the
    # mask still describes row `i` of the embeddings. A plain `repeat(n, 1)` tiles the whole batch
    # (p0, p1, p0, p1) while the embeddings above are laid out as (p0, p0, p1, p1).
    attention_mask = attention_mask.repeat(1, num_images_per_prompt)
    attention_mask = attention_mask.view(bs_embed * num_images_per_prompt, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]

        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        if negative_prompt is not None:
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            text_input_ids = uncond_input.input_ids.to(device)
            negative_attention_mask = uncond_input.attention_mask.to(device)

            negative_prompt_embeds = self.text_encoder(
                text_input_ids,
                attention_mask=negative_attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]
            # crop to the (possibly cut) sequence length of the positive prompt
            negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]]
            negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]]
            negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2)
        else:
            # No negative prompt: use an all-zero condition. These tensors are built from the already
            # duplicated positive tensors, so they need no further repetition below.
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_attention_mask = torch.zeros_like(attention_mask)

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
        if negative_prompt_embeds.shape != prompt_embeds.shape:
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
            # same row ordering as the embeddings (see the positive mask above)
            negative_attention_mask = negative_attention_mask.repeat(1, num_images_per_prompt)
            negative_attention_mask = negative_attention_mask.view(batch_size * num_images_per_prompt, -1)

        # For classifier free guidance, we need to do two forward passes.
        # The caller concatenates the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
    else:
        negative_prompt_embeds = None
        negative_attention_mask = None

    return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask
def check_inputs(
    self,
    prompt,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    attention_mask=None,
    negative_attention_mask=None,
):
    """
    Validate the arguments given to the pipeline call before any work is done.

    The checks run in a fixed order and the first violation raises a `ValueError`:
    `callback_steps`, the requested callback tensor names, the `prompt` / `prompt_embeds` pair, the
    `negative_prompt` / `negative_prompt_embeds` pair, and finally the shapes of the embeddings against each
    other and against their attention masks.

    Returns:
        `None` when every check passes.

    Raises:
        ValueError: On the first failed check.
    """
    # --- legacy callback frequency ---
    if callback_steps is not None:
        steps_ok = isinstance(callback_steps, int) and callback_steps > 0
        if not steps_ok:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    # --- tensors requested for `callback_on_step_end` ---
    if callback_on_step_end_tensor_inputs is not None:
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
            )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    has_negative_embeds = negative_prompt_embeds is not None

    # --- exactly one of `prompt` / `prompt_embeds`, and `prompt` must be a str or a list ---
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # --- at most one of `negative_prompt` / `negative_prompt_embeds` ---
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # --- positive and negative embeddings must agree in shape ---
    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    # --- negative embeddings need a mask of matching (batch, tokens) shape ---
    if has_negative_embeds:
        if negative_attention_mask is None:
            raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`")
        if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape:
            raise ValueError(
                "`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but"
                f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`"
                f" {negative_attention_mask.shape}."
            )

    # --- positive embeddings need a mask of matching (batch, tokens) shape ---
    if has_embeds:
        if attention_mask is None:
            raise ValueError("Please provide `attention_mask` along with `prompt_embeds`")
        if prompt_embeds.shape[:2] != attention_mask.shape:
            raise ValueError(
                "`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`"
                f" {attention_mask.shape}."
            )
A value of 1 essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 3.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask. Must provide if passing `prompt_embeds` directly. negative_attention_mask (`torch.Tensor`, *optional*): Pre-generated negative attention mask. 
Must provide if passing `negative_prompt_embeds` directly. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) cut_context = True # 1. 
Check inputs. Raise error if not correct self.check_inputs( prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, attention_mask, negative_attention_mask, ) self._guidance_scale = guidance_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. Encode input prompt prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt( prompt, self.do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, device=device, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, _cut_context=cut_context, attention_mask=attention_mask, negative_attention_mask=negative_attention_mask, ) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool() if not isinstance(image, list): image = [image] if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): raise ValueError( f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" ) image = torch.cat([prepare_image(i) for i in image], dim=0) image = image.to(dtype=prompt_embeds.dtype, device=device) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) # 5. 
Prepare latents latents = self.movq.encode(image)["latents"] latents = latents.repeat_interleave(num_images_per_prompt, dim=0) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) latents = self.prepare_latents( latents, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: self.text_encoder_offload_hook.offload() # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, encoder_attention_mask=attention_mask, )[0] if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, ).prev_sample if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) attention_mask = callback_outputs.pop("attention_mask", attention_mask) negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() 
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_sentencepiece_available,
    is_torch_available,
    is_transformers_available,
)


# Placeholder objects exposed when an optional dependency is missing.
_dummy_objects = {}
# Maps submodule name -> public names, consumed by `_LazyModule` for lazy imports.
_import_structure = {}

try:
    # All three backends are required. The negation has to wrap the whole conjunction: with the
    # sentencepiece check outside the parentheses, a missing sentencepiece made the condition False,
    # so the real modules were registered instead of the dummy objects.
    if not (is_transformers_available() and is_torch_available() and is_sentencepiece_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_and_sentencepiece_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects))
else:
    _import_structure["pipeline_kolors"] = ["KolorsPipeline"]
    _import_structure["pipeline_kolors_img2img"] = ["KolorsImg2ImgPipeline"]
    _import_structure["text_encoder"] = ["ChatGLMModel"]
    _import_structure["tokenizer"] = ["ChatGLMTokenizer"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        # Same requirement as above: every backend must be present.
        if not (is_transformers_available() and is_torch_available() and is_sentencepiece_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_and_sentencepiece_objects import *
    else:
        from .pipeline_kolors import KolorsPipeline
        from .pipeline_kolors_img2img import KolorsImg2ImgPipeline
        from .text_encoder import ChatGLMModel
        from .tokenizer import ChatGLMTokenizer

else:
    import sys

    # Replace this module with a lazy proxy so that submodules are only imported on first attribute access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure the scheduler through `set_timesteps` and return the resulting schedule.

    Exactly one of `num_inference_steps`, `timesteps` or `sigmas` drives the call; extra keyword arguments are
    forwarded to `scheduler.set_timesteps` unchanged.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler whose timesteps are configured and read back.
        num_inference_steps (`int`):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Device forwarded to `set_timesteps`.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's `timesteps` attribute after configuration, and the number of
        inference steps (the length of the schedule when a custom schedule was supplied).

    Raises:
        `ValueError`: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
        accept the custom argument that was passed.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own spacing.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom path: the scheduler must explicitly support the override.
    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
def encode_prompt(
    self,
    prompt,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 256,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device; falls back to `self._execution_device` when not given
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. If not provided, they will be generated from `prompt`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they will be generated from
            `negative_prompt` (or zero-filled, see below).
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. If not provided, they will be generated from
            `negative_prompt` (or zero-filled, see below).
        max_sequence_length (`int` defaults to 256):
            Maximum sequence length to use with the `prompt`.

    Returns:
        `Tuple`: `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)`.
        The negative entries are returned as passed in (possibly `None`) when classifier free guidance is off.

    Raises:
        `TypeError`: If `negative_prompt` and `prompt` are of different types.
        `ValueError`: If a list `negative_prompt` does not match the prompt batch size.
    """
    device = device or self._execution_device

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    def _run_text_encoder(texts):
        # Tokenize `texts` and return (per-token embeddings, pooled embeddings).
        tokens = self.tokenizer(
            texts,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        output = self.text_encoder(
            input_ids=tokens["input_ids"],
            attention_mask=tokens["attention_mask"],
            position_ids=tokens["position_ids"],
            output_hidden_states=True,
        )
        # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size]
        # clone to have a contiguous tensor
        sequence_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
        # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size]
        pooled_embeds = output.hidden_states[-1][-1, :, :].clone()
        return sequence_embeds, pooled_embeds

    if prompt_embeds is None:
        prompt_embeds, pooled_prompt_embeds = _run_text_encoder(prompt)
        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt

    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        # The pooled counterpart must be zero-filled as well; otherwise it stays `None` and the
        # `.repeat(...)` at the end of this method fails.
        if negative_pooled_prompt_embeds is None:
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        negative_prompt_embeds, negative_pooled_prompt_embeds = _run_text_encoder(uncond_tokens)

        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]
        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    # NOTE(review): caller-supplied `prompt_embeds` / `negative_prompt_embeds` are not expanded by
    # `num_images_per_prompt` here, while the pooled tensors below always are — confirm this is intended.
    bs_embed = pooled_prompt_embeds.shape[0]
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )

    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Scheduler `step` signatures differ, so `eta` and `generator` are only forwarded when the current scheduler's
    `step` declares a parameter of that name. `eta` corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Args:
        generator: Random generator(s) to forward to the scheduler, if accepted.
        eta: DDIM η value to forward to the scheduler, if accepted.

    Returns:
        `dict`: Subset of `{"eta": eta, "generator": generator}` that `self.scheduler.step` accepts.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    optional_args = (("eta", eta), ("generator", generator))
    return {arg_name: arg_value for arg_name, arg_value in optional_args if arg_name in step_params}
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the initial latents for denoising, scaled by the scheduler's `init_noise_sigma`.

    Args:
        batch_size: Effective batch size of the latents.
        num_channels_latents: Number of latent channels.
        height: Image height in pixels; divided by `self.vae_scale_factor` to get the latent height.
        width: Image width in pixels; divided by `self.vae_scale_factor` to get the latent width.
        dtype: dtype of freshly sampled latents.
        device: Device the latents are created on / moved to.
        generator: A generator or a list of generators used for sampling.
        latents: Optional pre-generated latents; when given they are only moved to `device`.

    Returns:
        The latents tensor multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        `ValueError`: If `generator` is a list whose length differs from `batch_size`.
    """
    downscale = self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, int(height) // downscale, int(width) // downscale)

    if isinstance(generator, list):
        if len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

    if latents is not None:
        initial = latents.to(device)
    else:
        initial = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Builds sinusoidal embeddings of the guidance scale(s) `w`: the first half of each row holds sines and the
    second half cosines, with one trailing zero column when `embedding_dim` is odd.

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scales to embed.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, on the same device as `w`.
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    # Clamp the divisor: with `embedding_dim` of 2 or 3, `half_dim - 1` is 0 and the unclamped
    # division produced NaN embeddings.
    log_step = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
    # Build the frequency table on `w`'s device so non-CPU inputs do not hit a device mismatch.
    freqs = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -log_step)
    emb = w.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, original_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Optional[Tuple[int, int]] = None, negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints that are not specifically fine-tuned on low resolutions. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints that are not specifically fine-tuned on low resolutions. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. 
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.kolors.KolorsPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a specific image resolution. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a target image resolution. It should be as same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. 
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. Examples: Returns: [`~pipelines.kolors.KolorsPipelineOutput`] or `tuple`: [`~pipelines.kolors.KolorsPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, num_inference_steps, height, width, negative_prompt, prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._interrupt = False # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. Encode input prompt ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, ) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) # 8.1 Apply denoising_end if ( self.denoising_end is not None and isinstance(self.denoising_end, float) and self.denoising_end > 0 and self.denoising_end < 1 ): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # 9. 
Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: xm.mark_step() if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) elif latents.dtype != self.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) # unscale/denormalize the latents latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents if not output_type == "latent": image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return KolorsPipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py ================================================ # Copyright 2024 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import PIL.Image import torch from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import KolorsPipelineOutput from .text_encoder import ChatGLMModel from .tokenizer import ChatGLMTokenizer if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import KolorsImg2ImgPipeline >>> from diffusers.utils import load_image >>> pipe = KolorsImg2ImgPipeline.from_pretrained( ... "Kwai-Kolors/Kolors-diffusers", variant="fp16", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> url = ( ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/bunny_source.png" ... ) >>> init_image = load_image(url) >>> prompt = "high quality image of a capybara wearing sunglasses. In the background of the image there are trees, poles, grass and other objects. At the bottom of the object there is the road., 8k, highly detailed." 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or (when neither is
    given) a plain `num_inference_steps` count. Extra `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler whose `set_timesteps` is called and whose `timesteps` attribute is read afterwards.
        num_inference_steps (`int`, *optional*):
            Number of steps passed positionally to `set_timesteps` when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as the `device` keyword.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration and the number of inference steps
        (the length of that schedule when a custom schedule was supplied, otherwise `num_inference_steps`).

    Raises:
        `ValueError`: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
        accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: no custom schedule requested, so no signature inspection is needed.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # A custom schedule was requested; make sure this scheduler can actually take it.
    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
The pipeline also inherits the following loading methods: - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`ChatGLMModel`]): Frozen text-encoder. Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b). tokenizer (`ChatGLMTokenizer`): Tokenizer of class [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"False"`): Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of `Kwai-Kolors/Kolors-diffusers`. 
# Offload order, one registered component name per `->`-separated entry.
# The previous value read `image_encoder-unet` (missing `>`), which is not the
# name of any component registered in `__init__`, so `image_encoder` and `unet`
# were never matched as individual entries of the sequence.
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
# Components that default to `None` in `__init__`.
_optional_components = [
    "image_encoder",
    "feature_extractor",
]
# Names accepted in `callback_on_step_end_tensor_inputs` (validated by `check_inputs`).
_callback_tensor_inputs = [
    "latents",
    "prompt_embeds",
    "negative_prompt_embeds",
    "add_text_embeds",
    "add_time_ids",
    "negative_pooled_prompt_embeds",
    "negative_add_time_ids",
]

def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: ChatGLMModel,
    tokenizer: ChatGLMTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    image_encoder: CLIPVisionModelWithProjection = None,
    feature_extractor: CLIPImageProcessor = None,
    force_zeros_for_empty_prompt: bool = False,
):
    """
    Register the pipeline components and derive the image-processing defaults.

    Args:
        vae: Autoencoder registered as `vae`; its `config.block_out_channels` determines `vae_scale_factor`.
        text_encoder: Text encoder registered as `text_encoder`.
        tokenizer: Tokenizer registered as `tokenizer`.
        unet: Denoising network registered as `unet`; its `config.sample_size` becomes `default_sample_size`.
        scheduler: Scheduler registered as `scheduler`.
        image_encoder: Optional image encoder registered as `image_encoder`.
        feature_extractor: Optional feature extractor registered as `feature_extractor`.
        force_zeros_for_empty_prompt: Stored in the pipeline config under the same name.
    """
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=feature_extractor,
    )
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

    # One factor of 2 per VAE block after the first; fall back to 8 when no VAE is registered.
    self.vae_scale_factor = (
        2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
    )
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    self.default_sample_size = self.unet.config.sample_size
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. 
""" # from IPython import embed; embed(); exit() device = device or self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer] text_encoders = [self.text_encoder] if prompt_embeds is None: prompt_embeds_list = [] for tokenizer, text_encoder in zip(tokenizers, text_encoders): text_inputs = tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_tensors="pt", ).to(device) output = text_encoder( input_ids=text_inputs["input_ids"], attention_mask=text_inputs["attention_mask"], position_ids=text_inputs["position_ids"], output_hidden_states=True, ) # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] # clone to have a contiguous tensor prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) prompt_embeds_list.append(prompt_embeds) prompt_embeds = prompt_embeds_list[0] # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same 
type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt negative_prompt_embeds_list = [] for tokenizer, text_encoder in zip(tokenizers, text_encoders): uncond_input = tokenizer( uncond_tokens, padding="max_length", max_length=max_sequence_length, truncation=True, return_tensors="pt", ).to(device) output = text_encoder( input_ids=uncond_input["input_ids"], attention_mask=uncond_input["attention_mask"], position_ids=uncond_input["position_ids"], output_hidden_states=True, ) # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] # clone to have a contiguous tensor negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view( batch_size * num_images_per_prompt, seq_len, -1 ) negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = negative_prompt_embeds_list[0] bs_embed = pooled_prompt_embeds.shape[0] pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed * num_images_per_prompt, -1 ) if do_classifier_free_guidance: 
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build the per-adapter image embeddings, one list entry per IP-Adapter.

    When `ip_adapter_image_embeds` is given it is used directly (each entry is split in two along dim 0 into
    negative / positive halves when `do_classifier_free_guidance` is set). Otherwise each image in
    `ip_adapter_image` is encoded with `self.encode_image`.

    Every resulting embedding is repeated `num_images_per_prompt` times along dim 0; with classifier-free
    guidance the repeated negative embedding is concatenated in front of the positive one. Each entry is moved
    to `device` before being returned.

    Raises:
        `ValueError`: If images are passed and their count differs from the number of
        `self.unet.encoder_hid_proj.image_projection_layers`.
    """
    positive_embeds = []
    negative_embeds = []

    if ip_adapter_image_embeds is not None:
        # Precomputed embeddings: with guidance, each tensor carries [negative, positive] along dim 0.
        for precomputed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_half, precomputed = precomputed.chunk(2)
                negative_embeds.append(negative_half)
            positive_embeds.append(precomputed)
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        projection_layers = self.unet.encoder_hid_proj.image_projection_layers
        if len(ip_adapter_image) != len(projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(projection_layers)} IP Adapters."
            )

        for adapter_image, projection_layer in zip(ip_adapter_image, projection_layers):
            # Anything other than a plain `ImageProjection` layer is fed hidden states.
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            encoded, encoded_negative = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positive_embeds.append(encoded[None, :])
            if do_classifier_free_guidance:
                negative_embeds.append(encoded_negative[None, :])

    prepared = []
    for index, positive in enumerate(positive_embeds):
        repeated = torch.cat([positive] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            repeated_negative = torch.cat([negative_embeds[index]] * num_images_per_prompt, dim=0)
            repeated = torch.cat([repeated_negative, repeated], dim=0)
        prepared.append(repeated.to(device=device))

    return prepared
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, strength, num_inference_steps, height, width, negative_prompt=None, prompt_embeds=None, pooled_prompt_embeds=None, negative_prompt_embeds=None, negative_pooled_prompt_embeds=None, ip_adapter_image=None, ip_adapter_image_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if not isinstance(num_inference_steps, int) or num_inference_steps <= 0: raise ValueError( f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" f" {type(num_inference_steps)}." ) if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
    """
    Select the part of `self.scheduler.timesteps` that the denoising loop should run.

    Args:
        num_inference_steps (`int`): Number of steps the scheduler was configured with.
        strength (`float`): Fraction of the schedule to run. Only used when `denoising_start` is `None`.
        device: Not used by this method; kept for interface compatibility.
        denoising_start (`float`, *optional*): When given, `strength` is ignored and only the timesteps
            strictly below `num_train_timesteps * (1 - denoising_start)` are kept.

    Returns:
        `tuple`: `(timesteps, num_inference_steps)` - the selected slice of the schedule and the number of
        inference steps that belongs to it.
    """
    if denoising_start is None:
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        return timesteps, num_inference_steps - t_start

    # Strength is irrelevant if we directly request a timestep to start at;
    # that is, strength is determined by the denoising_start instead.
    timesteps = self.scheduler.timesteps
    num_train_timesteps = self.scheduler.config.num_train_timesteps
    discrete_timestep_cutoff = int(round(num_train_timesteps - (denoising_start * num_train_timesteps)))

    num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
    if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
        # if the scheduler is a 2nd order scheduler we might have to do +1
        # because `num_inference_steps` might be even given that every timestep
        # (except the highest one) is duplicated. If `num_inference_steps` is even it would
        # mean that we cut the timesteps in the middle of the denoising step
        # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
        # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
        num_inference_steps = num_inference_steps + 1

    # because t_n+1 >= t_n, we keep the tail of the schedule. The start index is computed explicitly:
    # `timesteps[-0:]` would return the *whole* schedule when no timestep lies below the cutoff,
    # whereas the correct result in that case is an empty schedule.
    first_kept = max(len(timesteps) - num_inference_steps, 0)
    timesteps = timesteps[first_kept:]
    return timesteps, num_inference_steps
) elif isinstance(generator, list): if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " ) init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = retrieve_latents(self.vae.encode(image), generator=generator) if self.vae.config.force_upcast: self.vae.to(dtype) init_latents = init_latents.to(dtype) if latents_mean is not None and latents_std is not None: latents_mean = latents_mean.to(device=device, dtype=dtype) latents_std = latents_std.to(device=device, dtype=dtype) init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std else: init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
) else: init_latents = torch.cat([init_latents], dim=0) if add_noise: shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents return latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features if expected_add_embed_dim != passed_add_embed_dim: raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
# Adapted from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Build sinusoidal embeddings of the guidance scale values in `w`.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            One-dimensional tensor of guidance scale values; each entry gets its own embedding vector.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate. An odd value is reached by appending one zero column.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`; the sine components come
        first, followed by the cosine components.
    """
    assert w.ndim == 1
    scaled_w = w * 1000.0

    num_freqs = embedding_dim // 2
    log_step = torch.log(torch.tensor(10000.0)) / (num_freqs - 1)
    freqs = torch.exp(torch.arange(num_freqs, dtype=dtype) * -log_step)
    # outer product: one row per guidance value, one column per frequency
    angles = scaled_w.to(dtype).unsqueeze(1) * freqs.unsqueeze(0)
    embedding = torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)
    if embedding_dim % 2 == 1:
        # sin/cos halves only cover an even width; zero pad the last column
        embedding = torch.nn.functional.pad(embedding, (0, 1))
    assert embedding.shape == (scaled_w.shape[0], embedding_dim)
    return embedding
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, original_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Optional[Tuple[int, int]] = None, negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): The image(s) to modify with the pipeline. strength (`float`, *optional*, defaults to 0.3): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of `denoising_start` being declared as an integer, the value of `strength` will be ignored. 
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints that are not specifically fine-tuned on low resolutions. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints that are not specifically fine-tuned on low resolutions. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. denoising_start (`float`, *optional*): When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. 
The `denoising_start` parameter is particularly beneficial when this pipeline is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. 
If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.kolors.KolorsPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a specific image resolution. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a target image resolution. It should be as same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. Examples: Returns: [`~pipelines.kolors.KolorsPipelineOutput`] or `tuple`: [`~pipelines.kolors.KolorsPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, strength, num_inference_steps, height, width, negative_prompt, prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._denoising_start = denoising_start self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. 
Encode input prompt ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 4. Preprocess image image = self.image_processor.preprocess(image) # 5. Prepare timesteps def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, device, denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) add_noise = True if self.denoising_start is None else False # 6. Prepare latent variables if latents is None: latents = self.prepare_latents( image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator, add_noise, ) # 7. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) height, width = latents.shape[-2:] height = height * self.vae_scale_factor width = width * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 8. 
Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 9. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) # 9.1 Apply denoising_end if ( self.denoising_end is not None and self.denoising_start is not None and denoising_value_valid(self.denoising_end) and denoising_value_valid(self.denoising_start) and self.denoising_start >= self.denoising_end ): raise ValueError( f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + f" {self.denoising_end} when using type float." 
) elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # 9.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, 
return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: xm.mark_step() if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) elif latents.dtype != self.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
@dataclass
class KolorsPipelineOutput(BaseOutput):
    """
    Output class for Kolors pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of denoised PIL images of length `batch_size`, or a numpy array of shape `(batch_size, height,
            width, num_channels)`. The PIL images or the numpy array hold the denoised images produced by the
            diffusion pipeline.
    """

    # Generated images, annotated as either a list of PIL images or a single numpy array.
    images: Union[List[PIL.Image.Image], np.ndarray]
# See the License for the specific language governing permissions and # limitations under the License. import math from typing import List, Optional, Tuple import torch import torch.nn.functional as F from torch import nn from torch.nn import LayerNorm from torch.nn.utils import skip_init from transformers import PretrainedConfig, PreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithPast from ...utils import logging logger = logging.get_logger(__name__) class ChatGLMConfig(PretrainedConfig): model_type = "chatglm" def __init__( self, num_layers=28, padded_vocab_size=65024, hidden_size=4096, ffn_hidden_size=13696, kv_channels=128, num_attention_heads=32, seq_length=2048, hidden_dropout=0.0, classifier_dropout=None, attention_dropout=0.0, layernorm_epsilon=1e-5, rmsnorm=True, apply_residual_connection_post_layernorm=False, post_layer_norm=True, add_bias_linear=False, add_qkv_bias=False, bias_dropout_fusion=True, multi_query_attention=False, multi_query_group_num=1, apply_query_key_layer_scaling=True, attention_softmax_in_fp32=True, fp32_residual_connection=False, quantization_bit=0, pre_seq_len=None, prefix_projection=False, **kwargs, ): self.num_layers = num_layers self.vocab_size = padded_vocab_size self.padded_vocab_size = padded_vocab_size self.hidden_size = hidden_size self.ffn_hidden_size = ffn_hidden_size self.kv_channels = kv_channels self.num_attention_heads = num_attention_heads self.seq_length = seq_length self.hidden_dropout = hidden_dropout self.classifier_dropout = classifier_dropout self.attention_dropout = attention_dropout self.layernorm_epsilon = layernorm_epsilon self.rmsnorm = rmsnorm self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm self.post_layer_norm = post_layer_norm self.add_bias_linear = add_bias_linear self.add_qkv_bias = add_qkv_bias self.bias_dropout_fusion = bias_dropout_fusion self.multi_query_attention = multi_query_attention self.multi_query_group_num = 
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable scale.

    The statistics are computed in float32 and the result is cast back to the dtype of the input.
    """

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.eps = eps
        # `torch.empty` leaves the scale uninitialized; nothing in this class fills it in.
        scale = torch.empty(normalized_shape, device=device, dtype=dtype)
        self.weight = torch.nn.Parameter(scale)

    def forward(self, hidden_states: torch.Tensor):
        orig_dtype = hidden_states.dtype
        # Mean of squares over the feature axis, accumulated in float32.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = hidden_states * inv_rms

        return (self.weight * normalized).to(orig_dtype)
self.hidden_size_per_partition = projection_size self.hidden_size_per_attention_head = projection_size // config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) if self.apply_query_key_layer_scaling: coeff = self.layer_number self.norm_factor *= coeff self.coeff = coeff self.attention_dropout = torch.nn.Dropout(config.attention_dropout) def forward(self, query_layer, key_layer, value_layer, attention_mask): pytorch_major_version = int(torch.__version__.split(".")[0]) if pytorch_major_version >= 2: query_layer, key_layer, value_layer = [ k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer] ] if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: context_layer = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, value_layer, is_causal=True ) else: if attention_mask is not None: attention_mask = ~attention_mask context_layer = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, value_layer, attention_mask ) context_layer = context_layer.permute(2, 0, 1, 3) new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.reshape(*new_context_layer_shape) else: # Raw attention scores # [b, np, sq, sk] output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) # preallocting input tensor: [b * np, sq, sk] matmul_input_buffer = torch.empty( output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, device=query_layer.device, ) # Raw attention scores. 
[b * np, sq, sk] matmul_result = torch.baddbmm( matmul_input_buffer, query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=(1.0 / self.norm_factor), ) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) # =========================== # Attention probs and dropout # =========================== # attention scores and attention mask [b, np, sq, sk] if self.attention_softmax_in_fp32: attention_scores = attention_scores.float() if self.coeff is not None: attention_scores = attention_scores * self.coeff if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: attention_mask = torch.ones( output_size[0], 1, output_size[2], output_size[3], device=attention_scores.device, dtype=torch.bool ) attention_mask.tril_() attention_mask = ~attention_mask if attention_mask is not None: attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) attention_probs = F.softmax(attention_scores, dim=-1) attention_probs = attention_probs.type_as(value_layer) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = self.attention_dropout(attention_probs) # ========================= # Context layer. [sq, b, hp] # ========================= # value_layer -> context layer. 
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Tuple[torch.Tensor, ...]:
    """Split a tensor into equally sized chunks along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into. Must be positive and must divide the size
            of the last dimension evenly.
        contiguous_split_chunks: If True, make each chunk contiguous in memory.

    Returns:
        A tuple of `num_partitions` tensors. Unless `contiguous_split_chunks` is set, the chunks are the views
        returned by `torch.split`.

    Raises:
        ValueError: if `num_partitions` is not positive, or if the last dimension is not divisible by it.
    """
    if num_partitions <= 0:
        raise ValueError(f"`num_partitions` must be a positive integer, got {num_partitions}.")

    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size()[last_dim]
    # With a floored chunk size, `torch.split` would silently return more than `num_partitions` chunks for a
    # non-divisible dimension, so reject that case up front.
    if last_dim_size % num_partitions != 0:
        raise ValueError(
            f"The last dimension ({last_dim_size}) is not divisible by `num_partitions` ({num_partitions})."
        )

    chunk_size = last_dim_size // num_partitions
    chunks = torch.split(tensor, chunk_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in chunks)

    return chunks
self.dense = nn.Linear( self.projection_size, config.hidden_size, bias=config.add_bias_linear, device=device, **_config_to_kwargs(config), ) def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): if self.multi_query_attention: num_attention_heads = self.num_multi_query_groups_per_partition else: num_attention_heads = self.num_attention_heads_per_partition return torch.empty( inference_max_sequence_len, batch_size, num_attention_heads, self.hidden_size_per_attention_head, dtype=dtype, device=device, ) def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True): # hidden_states: [sq, b, h] # ================================================= # Pre-allocate memory for key-values for inference. # ================================================= # ===================== # Query, Key, and Value # ===================== # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer = self.query_key_value(hidden_states) if self.multi_query_attention: (query_layer, key_layer, value_layer) = mixed_x_layer.split( [ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, ], dim=-1, ) query_layer = query_layer.view( query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) ) key_layer = key_layer.view( key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) ) value_layer = value_layer.view( value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) ) else: new_tensor_shape = mixed_x_layer.size()[:-1] + ( self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head, ) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, 
class MLP(torch.nn.Module):
    """Gated feed-forward block.

    The input of width `hidden_size` is projected to `2 * ffn_hidden_size`, the two halves are combined as
    `silu(first) * second`, and the `ffn_hidden_size`-wide result is projected back to `hidden_size`.
    """

    def __init__(self, config: ChatGLMConfig, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear

        # Both projections share the same bias / device / dtype settings.
        linear_kwargs = {"bias": self.add_bias, "device": device, **_config_to_kwargs(config)}

        # Up-projection. The output is twice `ffn_hidden_size` wide because swiglu consumes it as two halves,
        # see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(config.hidden_size, config.ffn_hidden_size * 2, **linear_kwargs)

        def swiglu(x):
            halves = torch.chunk(x, 2, dim=-1)
            return F.silu(halves[0]) * halves[1]

        self.activation_func = swiglu

        # Down-projection back to the model width.
        self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, **linear_kwargs)

    def forward(self, hidden_states):
        # Widen, gate, then narrow again; the last dimension ends up at `hidden_size`.
        widened = self.dense_h_to_4h(hidden_states)
        gated = self.activation_func(widened)
        return self.dense_4h_to_h(gated)
class GLMTransformer(torch.nn.Module):
    """Stack of `config.num_layers` `GLMBlock` layers, optionally followed by a final layer norm."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(GLMTransformer, self).__init__()

        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers. Layer numbers passed to `GLMBlock` are 1-based.
        def build_layer(layer_number):
            return GLMBlock(config, layer_number, device=device)

        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])

        if self.post_layer_norm:
            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = LayerNormFunc(
                config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
            )

        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        """Return the layer stored at 0-based index `layer_number`."""
        return self.layers[layer_number]

    def forward(
        self,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_caches=None,
        use_cache: Optional[bool] = True,
        output_hidden_states: Optional[bool] = False,
    ):
        """Run every layer in order.

        Returns:
            A 4-tuple `(hidden_states, presents, all_hidden_states, all_self_attentions)`. `presents` holds one
            kv-cache entry per layer when caching is active and is `None` otherwise. `all_hidden_states` holds the
            input of every layer plus the output of the last one when `output_hidden_states` is set and is `None`
            otherwise. `all_self_attentions` is always `None`.
        """
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Build the cache container only once `use_cache` is final, so that a checkpointing-forced
        # `use_cache=False` yields `None` exactly like an explicit `use_cache=False`.
        presents = () if use_cache else None
        all_self_attentions = None
        all_hidden_states = () if output_hidden_states else None

        for index in range(self.num_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer = self._get_layer(index)
            if self.gradient_checkpointing and self.training:
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
                )
            else:
                layer_ret = layer(
                    hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index], use_cache=use_cache
                )
            hidden_states, kv_cache = layer_ret
            if use_cache:
                presents = presents + (kv_cache,)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states, presents, all_hidden_states, all_self_attentions
""" is_parallelizable = False supports_gradient_checkpointing = True config_class = ChatGLMConfig base_model_prefix = "transformer" _no_split_modules = ["GLMBlock"] def _init_weights(self, module: nn.Module): """Initialize the weights.""" return def get_masks(self, input_ids, past_key_values, padding_mask=None): batch_size, seq_length = input_ids.shape full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) full_attention_mask.tril_() past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[0] if past_length: full_attention_mask = torch.cat( (torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1 ) if padding_mask is not None: full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) if not past_length and padding_mask is not None: full_attention_mask -= padding_mask.unsqueeze(-1) - 1 full_attention_mask = (full_attention_mask < 0.5).bool() full_attention_mask.unsqueeze_(1) return full_attention_mask def get_position_ids(self, input_ids, device): batch_size, seq_length = input_ids.shape position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) return position_ids def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, GLMTransformer): module.gradient_checkpointing = value def default_init(cls, *args, **kwargs): return cls(*args, **kwargs) class Embedding(torch.nn.Module): """Language model embeddings.""" def __init__(self, config: ChatGLMConfig, device=None): super(Embedding, self).__init__() self.hidden_size = config.hidden_size # Word embeddings (parallel). self.word_embeddings = nn.Embedding( config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device ) self.fp32_residual_connection = config.fp32_residual_connection def forward(self, input_ids): # Embeddings. 
class RotaryEmbedding(nn.Module):
    """Builds the cos/sin lookup table used for rotary position embeddings."""

    def __init__(self, dim, original_impl=False, device=None, dtype=None):
        super().__init__()
        exponents = torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim
        self.register_buffer("inv_freq", 1.0 / (10000**exponents))
        self.dim = dim
        # Stored as given; no method of this class reads it.
        self.original_impl = original_impl

    def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000):
        """Compute the rotary table of shape `(seq_len, n_elem // 2, 2)` holding `(cos, sin)` pairs.

        Derived from the "Enhanced Transformer with Rotary Position Embedding" implementation at
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # Inverse frequencies: base ** (-2i / n_elem) for i in [0, n_elem / 2).
        exponents = torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem
        theta = 1.0 / (base**exponents)

        # Position indexes `[0, 1, ..., seq_len - 1]`.
        positions = torch.arange(seq_len, dtype=torch.float, device=device)

        # Angle for every (position, frequency) pair.
        angles = torch.outer(positions, theta).float()
        table = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype == torch.bfloat16:
            return table.bfloat16()
        if dtype in (torch.float16, torch.int8):
            return table.half()
        return table

    def forward(self, max_seq_len, offset=0):
        # `offset` is accepted but not used; the table always starts at position 0.
        buffer = self.inv_freq
        return self.forward_impl(max_seq_len, self.dim, dtype=buffer.dtype, device=buffer.device)
super().__init__(config) if empty_init: init_method = skip_init else: init_method = default_init init_kwargs = {} if device is not None: init_kwargs["device"] = device self.embedding = init_method(Embedding, config, **init_kwargs) self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels # Rotary positional embeddings self.seq_length = config.seq_length rotary_dim = ( config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels ) self.rotary_pos_emb = RotaryEmbedding( rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype ) self.encoder = init_method(GLMTransformer, config, **init_kwargs) self.output_layer = init_method( nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, dtype=config.torch_dtype, **init_kwargs, ) self.pre_seq_len = config.pre_seq_len self.prefix_projection = config.prefix_projection if self.pre_seq_len is not None: for param in self.parameters(): param.requires_grad = False self.prefix_tokens = torch.arange(self.pre_seq_len).long() self.prefix_encoder = PrefixEncoder(config) self.dropout = torch.nn.Dropout(0.1) def get_input_embeddings(self): return self.embedding.word_embeddings def get_prompt(self, batch_size, device, dtype=torch.half): prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) past_key_values = past_key_values.view( batch_size, self.pre_seq_len, self.num_layers * 2, self.multi_query_group_num, self.kv_channels ) # seq_len, b, nh, hidden_size past_key_values = self.dropout(past_key_values) past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) return past_key_values def forward( self, input_ids, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.BoolTensor] = None, full_attention_mask: Optional[torch.BoolTensor] = None, past_key_values: 
Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size, seq_length = input_ids.shape if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) if self.pre_seq_len is not None: if past_key_values is None: past_key_values = self.get_prompt( batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype ) if attention_mask is not None: attention_mask = torch.cat( [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1 ) if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) # Rotary positional embeddings rotary_pos_emb = self.rotary_pos_emb(self.seq_length) if position_ids is not None: rotary_pos_emb = rotary_pos_emb[position_ids] else: rotary_pos_emb = rotary_pos_emb[None, :seq_length] rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() # Run encoder. 
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states, ) if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions, ) ================================================ FILE: src/diffusers/pipelines/kolors/tokenizer.py ================================================ # Copyright 2024 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class SPTokenizer:
    """Thin wrapper around a SentencePiece model that appends ChatGLM control tokens.

    The control tokens (`[MASK]`, `[gMASK]`, `[sMASK]`, `sop`, `eop` and the four role markers) are not part of the
    SentencePiece vocabulary; they receive consecutive ids starting right after the last SentencePiece id.
    """

    def __init__(self, model_path: str):
        # The SentencePiece model file must exist before it can be loaded.
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # Ids reported by the SentencePiece model; padding reuses the unknown-token id.
        base_vocab = self.sp_model.vocab_size()
        self.n_words: int = base_vocab
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.unk_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

        role_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
        extra_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_tokens

        # Forward (token -> id) and reverse (id -> token) tables for the appended control tokens.
        self.special_tokens = {piece: base_vocab + offset for offset, piece in enumerate(extra_tokens)}
        self.index_special_tokens = {idx: piece for piece, idx in self.special_tokens.items()}
        self.n_words = base_vocab + len(extra_tokens)

        # Alternation regex that matches any role marker literally.
        self.role_special_token_expression = "|".join(re.escape(piece) for piece in role_tokens)

    def tokenize(self, s: str, encode_special_tokens=False):
        """Split `s` into pieces; with `encode_special_tokens`, role markers are kept as single pieces."""
        if not encode_special_tokens:
            return self.sp_model.EncodeAsPieces(s)

        pieces = []
        cursor = 0
        for hit in re.finditer(self.role_special_token_expression, s):
            begin, end = hit.start(), hit.end()
            # Plain text between two role markers goes through SentencePiece.
            if cursor < begin:
                pieces.extend(self.sp_model.EncodeAsPieces(s[cursor:begin]))
            pieces.append(s[begin:end])
            cursor = end
        if cursor < len(s):
            pieces.extend(self.sp_model.EncodeAsPieces(s[cursor:]))
        return pieces

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        """Encode `s` to SentencePiece ids, optionally wrapped with the BOS and/or EOS id."""
        assert isinstance(s, str)
        ids = self.sp_model.encode(s)
        if bos:
            ids = [self.bos_id] + ids
        if eos:
            ids = ids + [self.eos_id]
        return ids

    def decode(self, t: List[int]) -> str:
        """Decode ids to text, emitting control tokens verbatim between SentencePiece-decoded runs."""
        chunks = []
        pending = []
        for token_id in t:
            if token_id not in self.index_special_tokens:
                pending.append(token_id)
                continue
            # Flush the run of ordinary ids collected so far before writing the control token.
            if pending:
                chunks.append(self.sp_model.decode(pending))
                pending = []
            chunks.append(self.index_special_tokens[token_id])
        if pending:
            chunks.append(self.sp_model.decode(pending))
        return "".join(chunks)

    def decode_tokens(self, tokens: List[str]) -> str:
        """Join SentencePiece pieces back into a string."""
        return self.sp_model.DecodePieces(tokens)

    def convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        try:
            return self.special_tokens[token]
        except KeyError:
            return self.sp_model.PieceToId(token)

    def convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        if index in self.index_special_tokens:
            return self.index_special_tokens[index]
        # BOS/EOS/PAD and negative ids map to the empty string.
        if index in (self.eos_id, self.bos_id, self.pad_id) or index < 0:
            return ""
        return self.sp_model.IdToPiece(index)
class ChatGLMTokenizer(PreTrainedTokenizer):
    """ChatGLM tokenizer: a `PreTrainedTokenizer` front-end over `SPTokenizer`.

    Sequences are prefixed with `[gMASK] sop` and are always padded on the left. Besides `input_ids` and
    `attention_mask`, padding also produces `position_ids`.
    """

    vocab_files_names = {"vocab_file": "tokenizer.model"}
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(
        self,
        vocab_file,
        padding_side="left",
        clean_up_tokenization_spaces=False,
        encode_special_tokens=False,
        **kwargs,
    ):
        self.name = "GLMTokenizer"
        self.vocab_file = vocab_file
        self.tokenizer = SPTokenizer(vocab_file)
        # Each command token needs its own key; identical keys would collapse the mapping to a single
        # entry and make BOS/EOS lookups resolve to the padding id.
        self.special_tokens = {
            "<bos>": self.tokenizer.bos_id,
            "<eos>": self.tokenizer.eos_id,
            "<pad>": self.tokenizer.pad_id,
        }
        self.encode_special_tokens = encode_special_tokens

        super().__init__(
            padding_side=padding_side,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            encode_special_tokens=encode_special_tokens,
            **kwargs,
        )

    def get_command(self, token):
        """Return the id of a command token, looking first in this class's table, then in `SPTokenizer`'s."""
        if token in self.special_tokens:
            return self.special_tokens[token]
        assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
        return self.tokenizer.special_tokens[token]

    # The getters below return fixed strings; the setters only record the assigned value on a private
    # attribute and do not change what the getters return.
    @property
    def unk_token(self) -> str:
        return "<unk>"

    @unk_token.setter
    def unk_token(self, value: str):
        self._unk_token = value

    @property
    def pad_token(self) -> str:
        # Padding shares the unknown-token id in `SPTokenizer`, hence the shared surface form.
        return "<unk>"

    @pad_token.setter
    def pad_token(self, value: str):
        self._pad_token = value

    @property
    def pad_token_id(self):
        return self.get_command("<pad>")

    @property
    def eos_token(self) -> str:
        return "</s>"

    @eos_token.setter
    def eos_token(self, value: str):
        self._eos_token = value

    @property
    def eos_token_id(self):
        return self.get_command("<eos>")

    @property
    def vocab_size(self):
        return self.tokenizer.n_words

    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.tokenizer.convert_token_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.tokenizer.decode_tokens(tokens)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary. If it is not an existing directory, the value is
                used as the output file path itself.
            filename_prefix (`str`, *optional*):
                An optional prefix to add to the named of the saved files. Currently unused.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
        else:
            vocab_file = save_directory

        # The SentencePiece model is copied byte-for-byte from the file this tokenizer was loaded from.
        with open(self.vocab_file, "rb") as fin:
            proto_str = fin.read()

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)

    def get_prefix_tokens(self):
        """Return the ids of the `[gMASK]` and `sop` tokens that start every sequence."""
        return [self.get_command("[gMASK]"), self.get_command("sop")]

    def build_single_message(self, role, metadata, message):
        """Encode one chat turn as `<|role|>` + ids of `"{metadata}\\n"` + ids of `message`."""
        assert role in ["system", "user", "assistant", "observation"], role
        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
        message_tokens = self.tokenizer.encode(message)
        return role_tokens + message_tokens

    def build_chat_input(self, query, history=None, role="user"):
        """Encode `history` followed by `query` and a trailing `<|assistant|>` marker as a batch of one."""
        if history is None:
            history = []
        input_ids = []
        for item in history:
            content = item["content"]
            # Tool definitions attached to a system turn are appended to its text as JSON.
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
        input_ids.extend(self.build_single_message(role, "", query))
        input_ids.extend([self.get_command("<|assistant|>")])
        return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by adding the ChatGLM prefix tokens. The
        resulting layout is:

        - single sequence: `[gMASK] sop X`
        - pair of sequences: `[gMASK] sop A B <eos>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        prefix_tokens = self.get_prefix_tokens()
        token_ids_0 = prefix_tokens + token_ids_0
        if token_ids_1 is not None:
            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
        return token_ids_0

    def _pad(
        self,
        encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        padding_side: Optional[str] = None,
    ) -> dict:
        """
        Pad encoded inputs on the left, up to a predefined length or the length of the given sequence.

        Args:
            encoded_inputs:
                Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
            max_length: maximum length of the returned list and optionally padding length (see below).
                Will truncate by taking into account the special tokens.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length
                - PaddingStrategy.DO_NOT_PAD: Do not pad (default)
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                `>= 7.5` (Volta).
            return_attention_mask:
                Accepted for interface compatibility; an attention mask is always produced.
            padding_side:
                (optional) Per-call padding side. Accepted so that callers which forward this keyword do not fail;
                when `None`, `self.padding_side` is used. Only `"left"` is supported.

        Returns:
            `dict`: `encoded_inputs`, updated in place with `attention_mask`, `position_ids` and padded ids.
        """
        # Only left padding is implemented.
        effective_side = self.padding_side if padding_side is None else padding_side
        assert effective_side == "left"

        required_input = encoded_inputs[self.model_input_names[0]]
        seq_length = len(required_input)

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        # Round the target length up to the next multiple when requested.
        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

        # Initialize attention mask if not present.
        if "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * seq_length

        if "position_ids" not in encoded_inputs:
            encoded_inputs["position_ids"] = list(range(seq_length))

        if needs_to_be_padded:
            difference = max_length - len(required_input)

            # Padding positions get attention 0 and position id 0.
            encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
            encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

        return encoded_inputs
# Placeholder objects exported when the optional dependencies are missing.
_dummy_objects = {}
# Maps submodule name -> list of public names that the lazy module should expose from it.
_import_structure = {}


# Decide at import time what this package can export: the real pipelines need both
# `transformers` and `torch`; otherwise the dummy objects are collected instead.
try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_latent_consistency_img2img"] = ["LatentConsistencyModelImg2ImgPipeline"]
    _import_structure["pipeline_latent_consistency_text2img"] = ["LatentConsistencyModelPipeline"]

# Under static type checking, or when slow imports are requested, import the pipelines eagerly.
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline
        from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline

else:
    import sys

    # Replace this module's entry in `sys.modules` with a `_LazyModule` built from `_import_structure`.
    # NOTE(review): presumably this defers the submodule imports until first attribute access —
    # confirm against `_LazyModule`.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach the collected dummy objects to the replacement module.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Extract latents from an encoder output.

    Args:
        encoder_output: Object exposing either a `latent_dist` attribute or a `latents` attribute.
        generator: Passed to `latent_dist.sample` when `sample_mode` is `"sample"`.
        sample_mode: `"sample"` draws from `latent_dist`; `"argmax"` takes `latent_dist.mode()`.

    Returns:
        The sampled latents, the distribution mode, or the stored `latents` attribute.

    Raises:
        AttributeError: If none of the supported access paths applies.
    """
    if hasattr(encoder_output, "latent_dist"):
        distribution = encoder_output.latent_dist
        if sample_mode == "sample":
            return distribution.sample(generator)
        if sample_mode == "argmax":
            return distribution.mode()

    # Reached when there is no distribution, or the sample mode is not one of the two handled above.
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule. Supports a
    custom `timesteps` list or a custom `sigmas` list in place of `num_inference_steps`. Extra keyword
    arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps`
            is passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is
            passed, `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the
        scheduler and the second element is the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does
            not accept the custom argument that was supplied.
    """
    custom_timesteps = timesteps is not None
    custom_sigmas = sigmas is not None

    if custom_timesteps and custom_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its schedule from the step count.
    if not custom_timesteps and not custom_sigmas:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedules are only possible when `set_timesteps` declares the matching parameter.
    accepted_parameters = inspect.signature(scheduler.set_timesteps).parameters

    if custom_timesteps:
        if "timesteps" not in accepted_parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only supports [`LCMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. requires_safety_checker (`bool`, *optional*, defaults to `True`): Whether the pipeline requires a safety checker component. 
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: LCMScheduler,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
        requires_safety_checker: bool = True,
    ):
        """Register the pipeline components and set up image pre/post-processing.

        Logs a warning when `safety_checker` is `None` while `requires_safety_checker` is `True`.
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # Scale factor is 2 ** (number of VAE blocks - 1).
        # NOTE(review): presumably the pixel-to-latent spatial downsampling ratio — confirm against the VAE.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Returns:
            A `(prompt_embeds, negative_prompt_embeds)` tuple. `negative_prompt_embeds` is returned exactly as it
            was passed in (possibly `None`) when `do_classifier_free_guidance` is false.

        Raises:
            TypeError: If `negative_prompt` and `prompt` are of different types.
            ValueError: If a list `negative_prompt` does not match the prompt batch size.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # Batch size comes from the prompt when given, otherwise from the precomputed embeddings.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Tokenize again without truncation, only to report what was cut off.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            # The attention mask is only forwarded when the text encoder config asks for it.
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        # Target dtype: text encoder first, then UNet, then the embeddings' own dtype.
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad the negative prompt to the same sequence length as the positive embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True ).hidden_states[-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) return image_enc_hidden_states, uncond_image_enc_hidden_states else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): image_embeds = [] if do_classifier_free_guidance: negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): raise ValueError( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        """
        Build the per-adapter list of image embeddings passed to the UNet.

        Embeddings are either computed from `ip_adapter_image` (one entry per image projection layer) or taken
        from the precomputed `ip_adapter_image_embeds`. Each entry is repeated `num_images_per_prompt` times
        along dim 0; with classifier-free guidance the negative embeddings are concatenated in front.

        Returns:
            A list of tensors moved to `device`, one per adapter.

        Raises:
            ValueError: If the number of images differs from the number of image projection layers.
        """
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # Hidden states are requested for every projection layer that is not a plain `ImageProjection`.
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                # `[None, :]` adds a leading dimension to each adapter's embeddings.
                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    # Precomputed embeddings hold the negative half first, then the positive half.
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        """
        Run the optional safety checker on `image`.

        Returns:
            An `(image, has_nsfw_concept)` tuple. When no safety checker is configured, `image` is returned
            unchanged and `has_nsfw_concept` is `None`.
        """
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            # The feature extractor is fed PIL images, converted from either a tensor or a numpy batch.
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
    """
    Turn the input image into noised latents for the first denoising step.

    Args:
        image: Input image batch. It is moved to `device`/`dtype` with `.to(...)`; when its second dimension is
            4 it is used directly as latents, otherwise it is encoded with `self.vae`.
        timestep: Timestep passed to `self.scheduler.add_noise`.
        batch_size: Number of prompts; multiplied by `num_images_per_prompt` to get the effective batch size.
        num_images_per_prompt: Number of images generated per prompt.
        dtype: Dtype of the image tensor and of the sampled noise.
        device: Device of the image tensor and of the sampled noise.
        generator: A generator, or a list with one generator per sample of the effective batch.

    Returns:
        The initial latents with noise added by the scheduler.

    Raises:
        ValueError: For an unsupported `image` type, a generator list whose length differs from the effective
            batch size, or an image batch that cannot be evenly duplicated to the effective batch size.
    """
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)
    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        # Four channels: the input is treated as already being latents.
        init_latents = image
    else:
        one_generator_per_sample = isinstance(generator, list)
        if one_generator_per_sample and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if one_generator_per_sample:
            if image.shape[0] < batch_size:
                if batch_size % image.shape[0] != 0:
                    raise ValueError(
                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
                    )
                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)

            # Encode sample by sample so that each one is paired with its own generator.
            encoded = []
            for i in range(batch_size):
                encoded.append(retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]))
            init_latents = torch.cat(encoded, dim=0)
        else:
            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        init_latents = self.vae.config.scaling_factor * init_latents

    if batch_size > init_latents.shape[0]:
        if batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )

        # Fewer images than prompts: duplicate the latents (deprecated behaviour).
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        copies = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * copies, dim=0)
    else:
        init_latents = torch.cat([init_latents], dim=0)

    noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)

    # Noise the clean latents up to `timestep`.
    return self.scheduler.add_noise(init_latents, noise, timestep)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
    """
    Select the tail of the scheduler's timestep schedule that matches `strength`.

    Args:
        num_inference_steps: Length of the full schedule.
        strength: Fraction of the schedule to run; `int(num_inference_steps * strength)` steps are kept,
            capped at `num_inference_steps`.
        device: Not used by this method.

    Returns:
        A tuple of the remaining timesteps and the number of inference steps left to run.
    """
    scheduler = self.scheduler

    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    first_step = max(num_inference_steps - steps_to_run, 0)

    # The schedule holds `order` entries per inference step.
    offset = first_step * scheduler.order
    remaining = scheduler.timesteps[offset:]

    if hasattr(scheduler, "set_begin_index"):
        scheduler.set_begin_index(offset)

    return remaining, num_inference_steps - first_step
Optional[torch.Tensor] = None, ip_adapter_image=None, ip_adapter_image_embeds=None, callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: PipelineImageInput = None,
    num_inference_steps: int = 4,
    strength: float = 0.8,
    original_inference_steps: int = None,
    timesteps: List[int] = None,
    guidance_scale: float = 8.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        image (`PipelineImageInput`, *optional*):
            Image, or batch of images, used as the starting point. It is preprocessed with
            `self.image_processor` and, unless `latents` is passed, turned into the initial noised latents.
        num_inference_steps (`int`, *optional*, defaults to 4):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        strength (`float`, *optional*, defaults to 0.8):
            Must be between 0 and 1. The value is forwarded to the scheduler's `set_timesteps`, which uses it
            when building the timestep schedule.
        original_inference_steps (`int`, *optional*):
            The original number of inference steps use to generate a linearly-spaced timestep schedule, from
            which we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep
            schedule, following the Skipping-Step method in the paper (see Section 4.3). If not set this will
            default to the scheduler's `original_inference_steps` attribute.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced
            `num_inference_steps` timesteps on the original LCM training/distillation timestep schedule are
            used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 8.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            Note that the original latent consistency models paper uses a different CFG formulation where the
            guidance scales are decreased by 1 (so in the paper formulation CFG is enabled when
            `guidance_scale > 0`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents to be used as inputs for image generation. If not provided, latents are
            prepared from `image` using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """

    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        strength,
        callback_steps,
        prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs,
    )

    # A legacy `callback` without `callback_steps` would otherwise fail on `i % None` inside the loop;
    # call it on every step instead.
    if callback is not None and callback_steps is None:
        callback_steps = 1

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 3. Encode input prompt
    lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )

    # NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided
    # distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the
    # unconditional prompt "" (the empty string). Due to this, LCMs currently do not support negative prompts.
    prompt_embeds, _ = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=None,
        lora_scale=lora_scale,
        clip_skip=self.clip_skip,
    )

    # 4. Encode image
    image = self.image_processor.preprocess(image)

    # 5. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        original_inference_steps=original_inference_steps,
        strength=strength,
    )

    # 6. Prepare latent variables
    latent_timestep = timesteps[:1]
    if latents is None:
        latents = self.prepare_latents(
            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
        )
    bs = batch_size * num_images_per_prompt

    # 7. Get Guidance Scale Embedding
    # NOTE: We use the Imagen CFG formulation that StableDiffusionPipeline uses rather than the original LCM paper
    # CFG formulation, so we need to subtract 1 from the input guidance_scale.
    # LCM CFG formulation:  cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG)
    w = torch.tensor(self.guidance_scale - 1).repeat(bs)
    w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim).to(
        device=device, dtype=latents.dtype
    )

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)

    # 8.1 Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    # 9. LCM Multistep Sampling Loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            latents = latents.to(prompt_embeds.dtype)

            # model prediction (v-prediction, eps, x)
            model_pred = self.unet(
                latents,
                t,
                timestep_cond=w_embedding,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # compute the previous noisy sample x_t -> x_t-1
            latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False)
            if callback_on_step_end is not None:
                # The requested names were validated against `_callback_tensor_inputs` in `check_inputs`,
                # so each one is a local of this function.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                w_embedding = callback_outputs.pop("w_embedding", w_embedding)
                denoised = callback_outputs.pop("denoised", denoised)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    denoised = denoised.to(prompt_embeds.dtype)
    if not output_type == "latent":
        image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = denoised
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion # and https://github.com/hojonathanho/diffusion import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import LCMScheduler from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import DiffusionPipeline >>> import torch >>> pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7") >>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality. >>> pipe.to(torch_device="cuda", torch_dtype=torch.float32) >>> prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" >>> # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and read back the resulting schedule.

    Exactly one of `num_inference_steps`, `timesteps` or `sigmas` drives the schedule. Extra keyword arguments
    are forwarded to `scheduler.set_timesteps` unchanged.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure.
        num_inference_steps (`int`):
            Number of diffusion steps; used only when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Device passed on to `set_timesteps`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Cannot be combined with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Cannot be combined with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's timestep schedule and the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler space the steps itself.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedules are only possible when `set_timesteps` exposes the matching keyword.
    supported_arguments = inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if "timesteps" not in supported_arguments:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_arguments:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only supports [`LCMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. requires_safety_checker (`bool`, *optional*, defaults to `True`): Whether the pipeline requires a safety checker component. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: LCMScheduler,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    image_encoder: Optional[CLIPVisionModelWithProjection] = None,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and derive the image-processing helpers.

    A warning is logged when `safety_checker` is `None` while `requires_safety_checker` is true.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment needs the `f` prefix; without it `{self.__class__}` is printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    # Computed as 2 ** (number of VAE `block_out_channels` entries - 1).
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is 
None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into # the tuple to access the hidden states from the desired layer. prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode `image` with `self.image_encoder` and build a matching "unconditional" counterpart.

    Args:
        image:
            Either a `torch.Tensor` that is fed to the encoder directly, or any other input, which is first passed
            through `self.feature_extractor(..., return_tensors="pt")` and replaced by its `pixel_values`.
        device:
            Device the pixel tensor is moved to before encoding.
        num_images_per_prompt (`int`):
            Each encoded row is repeated this many times along dim 0 (`repeat_interleave`).
        output_hidden_states (`bool`, *optional*):
            When truthy, the second-to-last entry of the encoder's `hidden_states` is returned instead of
            `image_embeds`.

    Returns:
        `tuple`: `(conditional, unconditional)`. With `output_hidden_states` the unconditional part is the encoder
        output for an all-zero image; otherwise it is an all-zero tensor shaped like the conditional embeddings.
    """
    encoder = self.image_encoder
    # Follow the dtype of the encoder weights so the input matches the model precision.
    encoder_dtype = next(encoder.parameters()).dtype

    pixels = image
    if not isinstance(pixels, torch.Tensor):
        pixels = self.feature_extractor(pixels, return_tensors="pt").pixel_values
    pixels = pixels.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        pooled = encoder(pixels).image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return pooled, torch.zeros_like(pooled)

    def penultimate_states(batch):
        # Run the encoder and keep only the second-to-last hidden state, repeated per prompt.
        states = encoder(batch, output_hidden_states=True).hidden_states[-2]
        return states.repeat_interleave(num_images_per_prompt, dim=0)

    cond_states = penultimate_states(pixels)
    uncond_states = penultimate_states(torch.zeros_like(pixels))
    return cond_states, uncond_states
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """
    Run `self.safety_checker` on `image`, if a checker is configured.

    Args:
        image:
            Generated image batch. Tensors are converted with `self.image_processor.postprocess(...,
            output_type="pil")`; anything else goes through `self.image_processor.numpy_to_pil`.
        device:
            Device the feature-extractor output is moved to.
        dtype:
            Dtype the extracted `pixel_values` are cast to before being handed to the checker.

    Returns:
        `tuple`: `(image, has_nsfw_concept)`. When `self.safety_checker` is `None` the input image is returned
        untouched together with `None`; otherwise both values are whatever the checker returns.
    """
    if self.safety_checker is None:
        return image, None

    processor = self.image_processor
    pil_images = (
        processor.postprocess(image, output_type="pil") if torch.is_tensor(image) else processor.numpy_to_pil(image)
    )
    checker_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    clip_input = checker_features.pixel_values.to(dtype)

    checked_image, nsfw_flags = self.safety_checker(images=image, clip_input=clip_input)
    return checked_image, nsfw_flags
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Build sinusoidal embeddings for a batch of guidance scales.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            One-dimensional tensor of guidance scales. Each value is multiplied by 1000 before being embedded;
            the result is later used to enrich timestep embeddings.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate. Odd values are zero-padded by one column.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, laid out as all sines followed by
        all cosines, created on the same device as `w`.

    Raises:
        ValueError: If `w` is not one-dimensional.
    """
    # Explicit exception instead of `assert`, so the check survives `python -O`.
    if w.dim() != 1:
        raise ValueError(f"`w` has to be a 1-D tensor of guidance scales but has shape {tuple(w.shape)}.")

    w = w * 1000.0

    half_dim = embedding_dim // 2
    # `half_dim - 1` is zero when only one frequency is requested (embedding_dim 2 or 3), which used to
    # produce inf/NaN embeddings. Clamping the denominator leaves every larger dimension unchanged.
    log_step = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
    # Build the frequency table on the device of `w` so inputs that already live on an accelerator work.
    freqs = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -log_step)

    emb = w.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    # Internal sanity check on the construction above, not input validation.
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 4,
    original_inference_steps: int = None,
    timesteps: List[int] = None,
    guidance_scale: float = 8.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 4):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        original_inference_steps (`int`, *optional*):
            The original number of inference steps used to generate a linearly-spaced timestep schedule, from
            which we will draw `num_inference_steps` evenly spaced timesteps as our final timestep schedule,
            following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to
            the scheduler's `original_inference_steps` attribute.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced
            `num_inference_steps` timesteps on the original LCM training/distillation timestep schedule are
            used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 8.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            Note that the original latent consistency models paper uses a different CFG formulation where the
            guidance scales are decreased by 1 (so in the paper formulation CFG is enabled when
            `guidance_scale > 0`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means
            that the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference. The function is
            called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int,
            timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as
            specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
            the `._callback_tensor_inputs` attribute of your pipeline class.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is
            returned, otherwise a `tuple` is returned where the first element is a list with the generated
            images and the second element is a list of `bool`s indicating whether the corresponding generated
            image contains "not-safe-for-work" (nsfw) content.
    """

    # Legacy callback arguments are only accepted through **kwargs and trigger a deprecation notice.
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs,
    )
    # Stored on the instance so the `guidance_scale`, `clip_skip` and `cross_attention_kwargs` properties
    # reflect the values of the current call.
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # `image_embeds` is only bound when an IP-Adapter input was given; step 7.1 guards on the same condition.
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 3. Encode input prompt
    lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )

    # NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided
    # distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the
    # unconditional prompt "" (the empty string). Due to this, LCMs currently do not support negative prompts.
    prompt_embeds, _ = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=None,
        lora_scale=lora_scale,
        clip_skip=self.clip_skip,
    )

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, original_inference_steps=original_inference_steps
    )

    # 5. Prepare latent variable
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    bs = batch_size * num_images_per_prompt

    # 6. Get Guidance Scale Embedding
    # NOTE: We use the Imagen CFG formulation that StableDiffusionPipeline uses rather than the original LCM paper
    # CFG formulation, so we need to subtract 1 from the input guidance_scale.
    # LCM CFG formulation:  cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG)
    w = torch.tensor(self.guidance_scale - 1).repeat(bs)
    w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim).to(
        device=device, dtype=latents.dtype
    )

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)

    # 7.1 Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    # 8. LCM MultiStep Sampling Loop:
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            latents = latents.to(prompt_embeds.dtype)

            # model prediction (v-prediction, eps, x)
            model_pred = self.unet(
                latents,
                t,
                timestep_cond=w_embedding,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # compute the previous noisy sample x_t -> x_t-1
            latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False)
            if callback_on_step_end is not None:
                callback_kwargs = {}
                # Requested tensors are looked up by name among this function's local variables.
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # The callback may hand back replacements for any of these tensors.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                w_embedding = callback_outputs.pop("w_embedding", w_embedding)
                denoised = callback_outputs.pop("denoised", denoised)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                # NOTE(review): `callback_steps` stays `None` when the deprecated `callback` is passed without
                # it, in which case `i % callback_steps` cannot be evaluated -- confirm callers always pass both.
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    # The image is decoded from `denoised` (second value returned by `scheduler.step`), not from `latents`.
    denoised = denoised.to(prompt_embeds.dtype)
    if not output_type == "latent":
        image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = denoised
        has_nsfw_concept = None

    # Images flagged by the safety checker are not denormalized during post-processing.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
class LDMTextToImagePipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using latent diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) model to encode and decode images to and from latent representations.
        bert ([`LDMBertModel`]):
            Text-encoder model based on [`~transformers.BERT`].
        tokenizer ([`~transformers.BertTokenizer`]):
            A `BertTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """

    # Order in which sub-models are listed for CPU offload: text encoder, denoiser, decoder.
    model_cpu_offload_seq = "bert->unet->vqvae"

    def __init__(
        self,
        vqvae: Union[VQModel, AutoencoderKL],
        bert: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        unet: Union[UNet2DModel, UNet2DConditionModel],
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    ):
        """Register the sub-models and derive `vae_scale_factor` from the number of `vqvae` block channels."""
        super().__init__()
        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 1.0,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image. Must be divisible by 8.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image. Must be divisible by 8.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 1.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Classifier-free guidance is applied whenever
                `guidance_scale != 1.0`.
            eta (`float`, *optional*, defaults to 0.0):
                Forwarded to `scheduler.step` as `eta`, but only when the scheduler's `step` signature has an
                `eta` parameter; otherwise it is ignored.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.

        Example:

        ```py
        >>> from diffusers import DiffusionPipeline

        >>> # load model and scheduler
        >>> ldm = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        >>> # run pipeline in inference (sample random noise and denoise)
        >>> prompt = "A painting of a squirrel eating a burger"
        >>> images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images

        >>> # save images
        >>> for idx, image in enumerate(images):
        ...     image.save(f"squirrel-{idx}.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.

        Raises:
            ValueError: If `prompt` is neither `str` nor `list`, if `height`/`width` are not divisible by 8, if a
                list of generators does not match the batch size, or if `latents` has an unexpected shape.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get unconditional embeddings for classifier free guidance
        # (`negative_prompt_embeds` is only bound on this path; the loop below reads it under the same condition)
        if guidance_scale != 1.0:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
            )
            negative_prompt_embeds = self.bert(uncond_input.input_ids.to(self._execution_device))[0]

        # get prompt text embeddings
        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
        prompt_embeds = self.bert(text_input.input_ids.to(self._execution_device))[0]

        # get the initial random noise unless the user supplied it
        # NOTE(review): the latent size uses a hard-coded factor of 8 while the default height/width above use
        # `self.vae_scale_factor` -- confirm the two always agree for the supported checkpoints.
        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(
                latents_shape, generator=generator, device=self._execution_device, dtype=prompt_embeds.dtype
            )
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
        latents = latents.to(self._execution_device)

        self.scheduler.set_timesteps(num_inference_steps)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())

        extra_kwargs = {}
        if accepts_eta:
            extra_kwargs["eta"] = eta

        for t in self.progress_bar(self.scheduler.timesteps):
            if guidance_scale == 1.0:
                # guidance_scale of 1 means no guidance
                latents_input = latents
                context = prompt_embeds
            else:
                # For classifier free guidance, we need to do two forward passes.
                # Here we concatenate the unconditional and text embeddings into a single batch
                # to avoid doing two forward passes
                latents_input = torch.cat([latents] * 2)
                context = torch.cat([negative_prompt_embeds, prompt_embeds])

            # predict the noise residual
            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
            # perform guidance: the first half of the batch is the unconditional prediction, the second the
            # text-conditioned one (same order as the concatenation above)
            if guidance_scale != 1.0:
                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / self.vqvae.config.scaling_factor * latents
        image = self.vqvae.decode(latents).sample

        # map the decoder output from [-1, 1] to [0, 1] and move channels last for numpy/PIL conversion
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
class LDMBertConfig(PretrainedConfig):
    """
    Configuration holder for the LDMBert text encoder.

    The constructor only records its arguments as attributes and forwards `pad_token_id` plus any extra keyword
    arguments to `PretrainedConfig`. `attribute_map` exposes `encoder_attention_heads` as `num_attention_heads` and
    `d_model` as `hidden_size`.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Number of rows of the token embedding table built by `LDMBertEncoder`.
        max_position_embeddings (`int`, *optional*, defaults to 77):
            Number of rows of the position embedding table built by `LDMBertEncoder`.
        encoder_layers (`int`, *optional*, defaults to 32):
            Number of `LDMBertEncoderLayer` modules in the encoder; also stored as `num_hidden_layers`.
        encoder_ffn_dim (`int`, *optional*, defaults to 5120):
            Width of the feed-forward projection inside each encoder layer.
        encoder_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads per encoder layer.
        head_dim (`int`, *optional*, defaults to 64):
            Width of each attention head.
        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
            Stored on the config; not read by the encoder defined in this file.
        activation_function (`str`, *optional*, defaults to `"gelu"`):
            Key used to look up the feed-forward activation.
        d_model (`int`, *optional*, defaults to 1280):
            Width of the hidden states.
        dropout (`float`, *optional*, defaults to 0.1):
            Dropout probability applied to embeddings and residual branches.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied after the feed-forward activation.
        init_std (`float`, *optional*, defaults to 0.02):
            Standard deviation used by `LDMBertPreTrainedModel._init_weights`.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            Stored on the config; not read by the encoder defined in this file.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            Stored on the config; not read by the encoder defined in this file.
        use_cache (`bool`, *optional*, defaults to `True`):
            Stored on the config; not read by the encoder defined in this file.
        pad_token_id (`int`, *optional*, defaults to 0):
            Id of the padding token, forwarded to `PretrainedConfig`.
    """

    model_type = "ldmbert"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=77,
        encoder_layers=32,
        encoder_ffn_dim=5120,
        encoder_attention_heads=8,
        head_dim=64,
        encoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1280,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        use_cache=True,
        pad_token_id=0,
        **kwargs,
    ):
        # Embedding tables.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

        # Encoder geometry.
        self.d_model = d_model
        self.encoder_layers = encoder_layers
        self.num_hidden_layers = encoder_layers
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.head_dim = head_dim
        self.activation_function = activation_function

        # Regularisation.
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.classifier_dropout = classifier_dropout
        self.encoder_layerdrop = encoder_layerdrop

        # Initialisation and caching.
        self.init_std = init_std
        self.use_cache = use_cache

        super().__init__(pad_token_id=pad_token_id, **kwargs)
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert
class LDMBertAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    The per-head width is given explicitly through `head_dim`, so the projected width
    `inner_dim = head_dim * num_heads` does not have to equal `embed_dim`.

    Args:
        embed_dim (`int`):
            Size of the last dimension of the input and of the output hidden states.
        num_heads (`int`):
            Number of attention heads.
        head_dim (`int`):
            Width of each attention head.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to the attention probabilities while training.
        is_decoder (`bool`, *optional*, defaults to `False`):
            When `True`, `forward` returns the key/value states so that they can be fed back as `past_key_value`.
        bias (`bool`, *optional*, defaults to `False`):
            Whether the query/key/value projections carry a bias. The output projection always has one.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        head_dim: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = head_dim
        # Width of all heads concatenated; independent of `embed_dim`.
        self.inner_dim = head_dim * num_heads

        # Queries are multiplied by 1 / sqrt(head_dim) before the dot product with the keys.
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
        self.out_proj = nn.Linear(self.inner_dim, embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split the last dimension into heads and return a `(bsz, num_heads, seq_len, head_dim)` tensor."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states (`torch.Tensor`):
                Query source of shape `(bsz, tgt_len, embed_dim)`.
            key_value_states (`torch.Tensor`, *optional*):
                When given, keys and values are projected from this tensor (cross-attention) instead of from
                `hidden_states`.
            past_key_value (`Tuple[torch.Tensor]`, *optional*):
                Cached `(key_states, value_states)`. For cross-attention they are reused as is; for self-attention
                the newly projected states are concatenated after them along the sequence dimension.
            attention_mask (`torch.Tensor`, *optional*):
                Additive mask of shape `(bsz, 1, tgt_len, src_len)` that is summed with the attention scores before
                the softmax.
            layer_head_mask (`torch.Tensor`, *optional*):
                Multiplicative mask of shape `(num_heads,)` applied to the attention probabilities.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether to also return the attention probabilities.

        Returns:
            `tuple`: `(attn_output, attn_weights, past_key_value)` where `attn_output` has shape
            `(bsz, tgt_len, embed_dim)`, `attn_weights` has shape `(bsz, num_heads, tgt_len, src_len)` or is `None`
            when `output_attentions` is `False`, and `past_key_value` is the `(key_states, value_states)` pair when
            the layer was built with `is_decoder=True` (otherwise the argument is returned unchanged).

        Raises:
            ValueError: If the attention scores, `attention_mask`, `layer_head_mask` or the attention output do not
                have the expected size.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # Fold the head dimension into the batch dimension so a single `bmm` handles every head.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        # Report the same (head-folded) shape that is actually being checked.
        expected_output_size = (bsz * self.num_heads, tgt_len, self.head_dim)
        if attn_output.size() != expected_output_size:
            raise ValueError(
                f"`attn_output` should be of size {expected_output_size}, but is {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Merge the heads back using `inner_dim` (= num_heads * head_dim) stored on the module; it can differ
        # from the width of `hidden_states`, so the input shape cannot be reused here.
        attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert
class LDMBertPreTrainedModel(PreTrainedModel):
    """
    Common base for the LDMBert modules: binds them to `LDMBertConfig` and supplies weight initialisation, the
    gradient-checkpointing switch and a small set of dummy inputs.
    """

    # Class-level settings consumed by the `PreTrainedModel` base class.
    config_class = LDMBertConfig
    base_model_prefix = "model"
    _supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]

    def _init_weights(self, module):
        """
        Initialise `module` in place when it is an `nn.Linear` or `nn.Embedding`; other modules are left untouched.

        Weights are drawn from a normal distribution with mean 0 and standard deviation `config.init_std`. Linear
        biases and the embedding row at `padding_idx` (when set) are zeroed.
        """
        init_std = self.config.init_std

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` to `value` on `module` if it is an `LDMBertEncoder`."""
        if not isinstance(module, LDMBertEncoder):
            return
        module.gradient_checkpointing = value

    @property
    def dummy_inputs(self):
        """
        Return a dict with `"attention_mask"` and `"input_ids"` for a fixed two-sequence batch on `self.device`.

        The last token of the second sequence is the pad token, and the mask is `False` wherever the id equals the
        pad token.
        """
        pad_id = self.config.pad_token_id
        token_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_id]], device=self.device)
        return {
            "attention_mask": token_ids.ne(pad_id),
            "input_ids": token_ids,
        }
Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple. 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) seq_len = input_shape[1] if position_ids is None: position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1)) embed_pos = self.embed_positions(position_ids) hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: if head_mask.size()[0] != (len(self.layers)): raise ValueError( f"The head_mask should be specified for {len(self.layers)} layers, but it is for" f" {head_mask.size()[0]}." 
class LDMBertModel(LDMBertPreTrainedModel):
    """
    Wrapper that holds an `LDMBertEncoder` under the attribute `model`.

    A `to_logits` linear layer (`hidden_size` -> `vocab_size`) is also created, but `forward` does not apply it: the
    encoder output is returned unchanged.
    """

    _no_split_modules = []

    def __init__(self, config: LDMBertConfig):
        super().__init__(config)
        self.model = LDMBertEncoder(config)
        self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Forward every argument to the wrapped encoder and return its output as is."""
        encoder_kwargs = {
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "inputs_embeds": inputs_embeds,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        # `input_ids` stays positional, exactly as the encoder is invoked by the original implementation.
        return self.model(input_ids, **encoder_kwargs)
def preprocess(image):
    """
    Turn an image object into a batched float tensor scaled to `[-1, 1]`.

    Both sides are rounded down to a multiple of 32 and the image is resized with the Lanczos filter. The pixels
    are then converted to `float32`, divided by 255, given a leading batch axis and rearranged with
    `transpose(0, 3, 1, 2)`. That rearrangement requires the pixel array to be 3-dimensional; presumably this is a
    height x width x channel PIL image, to be confirmed against callers.

    Args:
        image: Object exposing `size` as `(width, height)` and a `resize` method, e.g. a `PIL.Image.Image`.

    Returns:
        `torch.Tensor`: 4-dimensional tensor with values `2 * (pixel / 255) - 1`.
    """
    width, height = image.size
    # resize to integer multiple of 32
    width -= width % 32
    height -= height % 32

    resized = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
    pixels = np.array(resized).astype(np.float32) / 255.0

    # Add the batch axis, then move the last axis to position 1.
    batch = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))
    return 2.0 * batch - 1.0
""" def __init__( self, vqvae: VQModel, unet: UNet2DModel, scheduler: Union[ DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ], ): super().__init__() self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) @torch.no_grad() def __call__( self, image: Union[torch.Tensor, PIL.Image.Image] = None, batch_size: Optional[int] = 1, num_inference_steps: Optional[int] = 100, eta: Optional[float] = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, ) -> Union[Tuple, ImagePipelineOutput]: r""" The call function to the pipeline for generation. Args: image (`torch.Tensor` or `PIL.Image.Image`): `Image` or tensor representing an image batch to be used as the starting point for the process. batch_size (`int`, *optional*, defaults to 1): Number of images to generate. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. 
Example: ```py >>> import requests >>> from PIL import Image >>> from io import BytesIO >>> from diffusers import LDMSuperResolutionPipeline >>> import torch >>> # load model and scheduler >>> pipeline = LDMSuperResolutionPipeline.from_pretrained("CompVis/ldm-super-resolution-4x-openimages") >>> pipeline = pipeline.to("cuda") >>> # let's download an image >>> url = ( ... "https://user-images.githubusercontent.com/38061659/199705896-b48e17b8-b231-47cd-a270-4ffa5a93fa3e.png" ... ) >>> response = requests.get(url) >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB") >>> low_res_img = low_res_img.resize((128, 128)) >>> # run pipeline in inference (sample random noise and denoise) >>> upscaled_image = pipeline(low_res_img, num_inference_steps=100, eta=1).images[0] >>> # save image >>> upscaled_image.save("ldm_generated_image.png") ``` Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images """ if isinstance(image, PIL.Image.Image): batch_size = 1 elif isinstance(image, torch.Tensor): batch_size = image.shape[0] else: raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}") if isinstance(image, PIL.Image.Image): image = preprocess(image) height, width = image.shape[-2:] # in_channels should be 6: 3 for latents, 3 for low resolution image latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width) latents_dtype = next(self.unet.parameters()).dtype latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) image = image.to(device=self.device, dtype=latents_dtype) # set timesteps and move to the correct device self.scheduler.set_timesteps(num_inference_steps, device=self.device) timesteps_tensor = self.scheduler.timesteps # scale the initial noise by the standard deviation required by 
# Package initialiser for the Latte pipeline. It exposes `LattePipeline` when both torch and transformers are
# available and falls back to the dummy placeholder objects otherwise.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Placeholder objects attached to this module when the optional dependencies are missing.
_dummy_objects = {}
# Maps a submodule name to the public names it provides; handed to `_LazyModule` below.
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    # Both dependencies are present: register the real pipeline.
    _import_structure["pipeline_latte"] = ["LattePipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: import the real class (or the dummies) immediately.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_latte import LattePipeline

else:
    import sys

    # Replace this module in `sys.modules` with a `_LazyModule` built from `_import_structure`.
    # NOTE(review): presumably `_LazyModule` defers the submodule import until attribute access - confirm in utils.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach the dummy placeholders (if any were collected) to the replacement module.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Extra keyword arguments are forwarded untouched to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as the `device` keyword.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration, and the number of inference steps
        (the length of the schedule when a custom schedule was supplied, otherwise `num_inference_steps`).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    custom_timesteps = timesteps is not None
    custom_sigmas = sigmas is not None

    if custom_timesteps and custom_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Plain mode: let the scheduler derive its own spacing from the step count.
    if not custom_timesteps and not custom_sigmas:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: the scheduler must explicitly accept the keyword we are about to pass.
    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if custom_timesteps:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
""" bad_punct_regex = re.compile(r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\\*]{1,}") _optional_components = ["tokenizer", "text_encoder"] model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = [ "latents", "prompt_embeds", "negative_prompt_embeds", ] def __init__( self, tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, vae: AutoencoderKL, transformer: LatteTransformer3DModel, scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py def mask_text_embeddings(self, emb, mask): if emb.shape[0] == 1: keep_index = mask.sum().item() return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096 else: masked_feature = emb * mask[:, None, :, None] # 1 120 4096 return masked_feature, emb.shape[2] # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], do_classifier_free_guidance: bool = True, negative_prompt: str = "", num_images_per_prompt: int = 1, device: Optional[torch.device] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, clean_caption: bool = False, mask_feature: bool = True, dtype=None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded negative_prompt (`str` or `List[str]`, *optional*): The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For Latte, this should be "". 
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not num_images_per_prompt (`int`, *optional*, defaults to 1): number of video that should be generated per prompt device: (`torch.device`, *optional*): torch device to place the resulting embeddings on prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. For Latte, it's should be the embeddings of the "" string. clean_caption (bool, defaults to `False`): If `True`, the function will preprocess and clean the provided caption before encoding. mask_feature: (bool, defaults to `True`): If `True`, the function will mask the text embeddings. """ embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] max_length = 120 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" 
{max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds_attention_mask = attention_mask prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds_attention_mask = torch.ones_like(prompt_embeds) if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.transformer is not None: dtype = self.transformer.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1) prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = 
negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes else: negative_prompt_embeds = None # Perform additional masking. if mask_feature and not embeds_initially_provided: prompt_embeds = prompt_embeds.unsqueeze(1) masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask) masked_prompt_embeds = masked_prompt_embeds.squeeze(1) masked_negative_prompt_embeds = ( negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None ) return masked_prompt_embeds, masked_negative_prompt_embeds return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, height, width, negative_prompt, callback_on_step_end_tensor_inputs, prompt_embeds=None, negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
def _clean_caption(self, caption):
    """
    Normalize a raw caption before tokenization.

    The caption is URL-unquoted, lower-cased and then passed through a fixed sequence of regex substitutions
    that strip URLs, HTML markup, @-mentions, several CJK ranges, IP addresses, ids, file names, boilerplate
    shopping/download phrases and noisy punctuation. `ftfy` and a double `html.unescape` repair the remaining
    text. The order of the substitutions matters and is kept as is.

    Args:
        caption: Any object; it is converted with `str()` first.

    Returns:
        `str`: the cleaned caption with surrounding whitespace removed.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # Replace the literal "<person>" placeholder. An empty pattern here would insert "person" between
    # every character of the caption.
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # &quot; (HTML entity left over in scraped text; the trailing ";" is optional)
    caption = re.sub(r"&quot;?", "", caption)
    # &amp
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # \n (a literal backslash followed by "n")
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # collapse repeated quotes and dots
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    # only split on dashes/underscores when there are more than three of them
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    # unescape twice to also resolve double-escaped entities
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE: a bare `caption.strip()` used to sit here; its result was discarded, so it had no effect and
    # has been dropped. The anchored patterns below therefore still see any leading/trailing space.

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: str = "",
    num_inference_steps: int = 50,
    timesteps: Optional[List[int]] = None,
    guidance_scale: float = 7.5,
    num_images_per_prompt: int = 1,
    video_length: int = 16,
    height: int = 512,
    width: int = 512,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: str = "pil",
    return_dict: bool = True,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    clean_caption: bool = True,
    mask_feature: bool = True,
    enable_temporal_attentions: bool = True,
    decode_chunk_size: Optional[int] = None,
) -> Union[LattePipelineOutput, Tuple]:
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the video generation. If not defined, one has to pass
            `prompt_embeds` instead.
        negative_prompt (`str` or `List[str]`, *optional*, defaults to `""`):
            The prompt or prompts not to guide the video generation. Only used when guidance is enabled
            (`guidance_scale > 1`) and `negative_prompt_embeds` is not passed.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality video at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced
            `num_inference_steps` timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of
            [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked
            to the text `prompt`, usually at the expense of lower video quality.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt.
        video_length (`int`, *optional*, defaults to 16):
            The number of video frames that are generated.
        height (`int`, *optional*, defaults to 512):
            The height in pixels of the generated video. A falsy value falls back to
            `self.transformer.config.sample_size * self.vae_scale_factor`.
        width (`int`, *optional*, defaults to 512):
            The width in pixels of the generated video. A falsy value falls back to
            `self.transformer.config.sample_size * self.vae_scale_factor`.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only
            forwarded to schedulers whose `step` accepts it.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch
            generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
            deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
            generation. If not provided, a latents tensor will be generated by sampling using the supplied
            random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. For Latte this negative prompt should be "". If not
            provided, negative_prompt_embeds will be generated from `negative_prompt`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated video. `"latents"` returns the denoised latents without
            decoding; any other value is forwarded to `VideoProcessor.postprocess_video`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] instead of a
            plain tuple.
        callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A callback function or a list of callback functions to be called at the end of each denoising
            step.
        callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to `["latents"]`):
            Names of the tensors handed to the callback. Each name must be listed in
            `self._callback_tensor_inputs`.
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and
            `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from
            the raw prompt.
        mask_feature (`bool` defaults to `True`):
            If set to `True`, the text embeddings will be masked.
        enable_temporal_attentions (`bool`, *optional*, defaults to `True`):
            Whether to enable temporal attentions.
        decode_chunk_size (`int`, *optional*):
            The number of frames to decode at a time. Higher chunk size leads to better temporal consistency
            at the expense of more memory usage. By default (`None`), `video_length` is used as the chunk
            size. For lower memory usage, reduce `decode_chunk_size`.

    Examples:

    Returns:
        [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is the generated video.
    """
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 0. Defaults
    decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else video_length
    height = height or self.transformer.config.sample_size * self.vae_scale_factor
    width = width or self.transformer.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds,
        negative_prompt_embeds,
    )
    self._guidance_scale = guidance_scale
    self._interrupt = False

    # 2. Derive the batch size from the prompt(s) or from the precomputed embeddings
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
        mask_feature=mask_feature,
    )
    if do_classifier_free_guidance:
        # unconditional half first, text half second; the chunk(2) below relies on this order
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    self._num_timesteps = len(timesteps)

    # 5. Prepare latents.
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        latent_channels,
        video_length,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # once interrupted, the remaining timesteps are skipped without touching the latents
            if self.interrupt:
                continue

            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            current_timestep = t
            if not torch.is_tensor(current_timestep):
                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                # This would be a good case for the `match` statement (Python 3.10+)
                is_mps = latent_model_input.device.type == "mps"
                if isinstance(current_timestep, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
            elif len(current_timestep.shape) == 0:
                current_timestep = current_timestep[None].to(latent_model_input.device)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            current_timestep = current_timestep.expand(latent_model_input.shape[0])

            # predict noise model_output
            noise_pred = self.transformer(
                latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=current_timestep,
                enable_temporal_attentions=enable_temporal_attentions,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # use learned sigma? If the scheduler does not consume a learned variance, keep only the first
            # half of the channel dimension of the model output.
            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred = noise_pred.chunk(2, dim=1)[0]

            # compute previous video: x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    # tensors are looked up by name among this function's local variables
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    if not output_type == "latents":
        # use the resolved chunk size instead of a hard-coded value so the argument actually takes effect
        video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size)
        video = self.video_processor.postprocess_video(video=video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return LattePipelineOutput(frames=video)
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Placeholder objects that are attached to the module when torch/transformers are unavailable.
_dummy_objects = {}
# Maps a submodule name to the public names it provides; handed to `_LazyModule` below.
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_leditspp_stable_diffusion"] = ["LEditsPPPipelineStableDiffusion"]
    _import_structure["pipeline_leditspp_stable_diffusion_xl"] = ["LEditsPPPipelineStableDiffusionXL"]

    # Both output classes must be listed: the eager branch below imports
    # `LEditsPPInversionPipelineOutput` as well, so the lazy branch has to expose it too.
    _import_structure["pipeline_output"] = ["LEditsPPDiffusionPipelineOutput", "LEditsPPInversionPipelineOutput"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: import the real objects (or the dummy placeholders) right away.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_leditspp_stable_diffusion import (
            LEditsPPDiffusionPipelineOutput,
            LEditsPPInversionPipelineOutput,
            LEditsPPPipelineStableDiffusion,
        )
        from .pipeline_leditspp_stable_diffusion_xl import LEditsPPPipelineStableDiffusionXL

else:
    import sys

    # Lazy path: replace this module object with a `_LazyModule` built from `_import_structure`.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
class LeditsAttentionStore:
    """
    Accumulates attention maps handed over by attention processors.

    Maps are first gathered per step in `step_store` (keyed by `"<place>_<cross|self>"`) and moved into
    `attention_store` by `between_steps`. With `average=True` the stored maps are summed across steps and divided by
    the number of stored steps on retrieval; otherwise one dictionary per stored step is kept in a list.

    Args:
        average (`bool`): Whether stored steps are summed (and averaged on retrieval) or kept separately.
        batch_size (`int`, defaults to 1): Chunk size used to split incoming attention along its first dimension.
        max_resolution (`int`, defaults to 16): Maps with more than `max_resolution**2` query positions are ignored.
        max_size (`int`, *optional*): Alternative upper bound; only accepted when `max_resolution` is `None`.

    Raises:
        ValueError: If `max_size` is given while `max_resolution` is not `None`.
    """

    def __init__(self, average: bool, batch_size=1, max_resolution=16, max_size: int = None):
        self.step_store = self.get_empty_store()
        self.attention_store = []
        self.cur_step = 0
        self.average = average
        self.batch_size = batch_size
        if max_size is None:
            self.max_size = max_resolution**2
        elif max_resolution is None:
            self.max_size = max_size
        else:
            raise ValueError("Only allowed to set one of max_resolution or max_size")

    @staticmethod
    def get_empty_store():
        """Return a fresh dictionary with one empty list per (location, attention kind) pair."""
        return {f"{place}_{kind}": [] for kind in ("cross", "self") for place in ("down", "mid", "up")}

    def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False):
        """Record `attn` for the current step unless its second dimension exceeds `max_size`."""
        # attn.shape = batch_size * head_size, seq_len query, seq_len_key
        if attn.shape[1] > self.max_size:
            return

        leading_groups = 2 if PnP else 1  # skip PnP & unconditional
        groups_per_sample = 1 + int(PnP) + editing_prompts

        chunks = attn.split(self.batch_size)
        stacked = torch.stack(chunks).permute(1, 0, 2, 3)
        source_batch_size = int(stacked.shape[1] // groups_per_sample)
        self.forward(stacked[:, leading_groups * source_batch_size :], is_cross, place_in_unet)

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Append `attn` to the list of the current step that matches the location and attention kind."""
        kind = "cross" if is_cross else "self"
        self.step_store[f"{place_in_unet}_{kind}"].append(attn)

    def between_steps(self, store_step=True):
        """Finish the current step: optionally persist its maps, then start with an empty step store."""
        if store_step:
            if not self.average:
                if len(self.attention_store) == 0:
                    self.attention_store = [self.step_store]
                else:
                    self.attention_store.append(self.step_store)
            elif len(self.attention_store) == 0:
                self.attention_store = self.step_store
            else:
                # running sum; the division by the step count happens in `get_attention`
                for key in self.attention_store:
                    accumulated = self.attention_store[key]
                    for idx in range(len(accumulated)):
                        accumulated[idx] += self.step_store[key][idx]
            self.cur_step += 1

        self.step_store = self.get_empty_store()

    def get_attention(self, step: int):
        """Return the step-averaged maps (`average=True`) or the maps stored for `step`."""
        if not self.average:
            assert step is not None
            return self.attention_store[step]

        averaged = {}
        for key in self.attention_store:
            averaged[key] = [item / self.cur_step for item in self.attention_store[key]]
        return averaged

    def aggregate_attention(
        self, attention_maps, prompts, res: Union[int, Tuple[int]], from_where: List[str], is_cross: bool, select: int
    ):
        """
        Collect the maps of the requested locations whose second dimension equals the pixel count of `res`, reshape
        them to that resolution, pick entry `select` along the prompt dimension and average over dimension 1 of the
        stacked result.
        """
        if isinstance(res, int):
            resolution = (res, res)
            num_pixels = res**2
        else:
            resolution = res[:2]
            num_pixels = res[0] * res[1]

        kind = "cross" if is_cross else "self"
        per_sample = [[] for _ in range(self.batch_size)]
        for location in from_where:
            for step_maps in attention_maps[f"{location}_{kind}"]:
                for sample_idx, sample_map in enumerate(step_maps):
                    if sample_map.shape[1] != num_pixels:
                        continue
                    reshaped = sample_map.reshape(len(prompts), -1, *resolution, sample_map.shape[-1])
                    per_sample[sample_idx].append(reshaped[select])

        stacked = torch.stack([torch.cat(maps, dim=0) for maps in per_sample])
        # average over heads
        return stacked.sum(1) / stacked.shape[1]
class LEDITSCrossAttnProcessor:
    """
    Attention processor that computes attention explicitly and hands the attention probabilities to an attention
    store before applying them.

    Args:
        attention_store: Callable receiving the attention probabilities of every call.
        place_in_unet: Location label forwarded to the store together with the probabilities.
        pnp: Forwarded to the store as `PnP`.
        editing_prompts: Forwarded to the store as `editing_prompts`.
    """

    def __init__(self, attention_store, place_in_unet, pnp, editing_prompts):
        self.attnstore = attention_store
        self.place_in_unet = place_in_unet
        self.editing_prompts = editing_prompts
        self.pnp = pnp

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states,
        attention_mask=None,
        temb=None,
    ):
        """
        Run attention for `hidden_states` (against `encoder_hidden_states` when given, otherwise against itself),
        report the attention probabilities to the store with `is_cross=True`, and return the projected output
        divided by `attn.rescale_output_factor`. `temb` is accepted but not used.
        """
        # batch size and sequence length are taken from the tensor that provides keys and values
        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        self.attnstore(
            attention_probs,
            is_cross=True,
            place_in_unet=self.place_in_unet,
            editing_prompts=self.editing_prompts,
            PnP=self.pnp,
        )

        attended = torch.bmm(attention_probs, value)
        attended = attn.batch_to_head_dim(attended)

        # linear proj
        projected = attn.to_out[0](attended)
        # dropout
        projected = attn.to_out[1](projected)

        return projected / attn.rescale_output_factor
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    The guided prediction is scaled so that its standard deviation (taken over all non-batch dimensions) matches the
    one of `noise_pred_text`, and the result is blended with the unscaled `noise_cfg` using weight `guidance_rescale`.
    """
    text_dims = list(range(1, noise_pred_text.ndim))
    cfg_dims = list(range(1, noise_cfg.ndim))

    std_text = noise_pred_text.std(dim=text_dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=cfg_dims, keepdim=True)

    # rescale the results from guidance (fixes overexposure)
    std_ratio = std_text / std_cfg
    rescaled = noise_cfg * std_ratio

    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: Union[DDIMScheduler, DPMSolverMultistepScheduler],
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Validate and register the pipeline components.

    A scheduler that is neither a `DDIMScheduler` nor a `DPMSolverMultistepScheduler` is replaced by a
    `DPMSolverMultistepScheduler` built from its config (`algorithm_type="sde-dpmsolver++"`, `solver_order=2`).
    Outdated scheduler settings (`steps_offset != 1`, `clip_sample is True`) and an outdated unet `sample_size`
    are patched in place after emitting a deprecation message.

    Raises:
        ValueError: If a `safety_checker` is passed without a `feature_extractor`.
    """
    super().__init__()

    if not isinstance(scheduler, (DDIMScheduler, DPMSolverMultistepScheduler)):
        scheduler = DPMSolverMultistepScheduler.from_config(
            scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2
        )
        logger.warning(
            "This pipeline only supports DDIMScheduler and DPMSolverMultistepScheduler. "
            "The scheduler has been changed to DPMSolverMultistepScheduler."
        )

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        # the config is frozen, so a patched copy replaces the internal dict
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # f-string so that the class name is interpolated instead of printing the raw placeholder
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)

    # no inversion has been run yet
    self.inversion_steps = None
def prepare_extra_step_kwargs(self, eta, generator=None):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only included when the scheduler's
    `step` method declares a parameter of that name.

    Args:
        eta: Corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].
        generator: Random generator passed through to the scheduler when supported.

    Returns:
        `dict`: Contains `"eta"` and/or `"generator"` for the parameters the scheduler accepts.
    """
    accepted = set(inspect.signature(self.scheduler.step).parameters.keys())
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in accepted}
def encode_prompt(
    self,
    device,
    num_images_per_prompt,
    enable_edit_guidance,
    negative_prompt=None,
    editing_prompt=None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    editing_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        enable_edit_guidance (`bool`):
            whether to perform any editing or reconstruct the input image instead
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. A single string is applied to every image of the batch; a list has
            to contain `self.batch_size` entries.
        editing_prompt (`str` or `List[str]`, *optional*):
            Editing prompt(s) to be encoded. If not defined, one has to pass `editing_prompt_embeds` instead.
        editing_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from the `editing_prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(editing_prompt_embeds, negative_prompt_embeds, num_edit_tokens)`. `num_edit_tokens` stays `None`
        unless the editing prompts were tokenized here.

    Raises:
        ValueError: If `negative_prompt` is a list whose length differs from `self.batch_size`.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    batch_size = self.batch_size
    num_edit_tokens = None

    if negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif isinstance(negative_prompt, str):
            # broadcast a single negative prompt to every image so that the reshape to
            # `batch_size * num_images_per_prompt` further down is valid for batch_size > 1
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but expected"
                f" {batch_size} based on the input images. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = negative_prompt_embeds.dtype

    negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    if enable_edit_guidance:
        if editing_prompt_embeds is None:
            if isinstance(editing_prompt, str):
                editing_prompt = [editing_prompt]

            # every editing prompt is repeated once per image of the batch
            repeated_prompts = [x for item in editing_prompt for x in repeat(item, batch_size)]

            max_length = negative_prompt_embeds.shape[1]
            text_inputs = self.tokenizer(
                repeated_prompts,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
                return_length=True,
            )

            num_edit_tokens = text_inputs.length - 2  # not counting startoftext and endoftext
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(
                repeated_prompts,
                padding="longest",
                return_tensors="pt",
            ).input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if (
                hasattr(self.text_encoder.config, "use_attention_mask")
                and self.text_encoder.config.use_attention_mask
            ):
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                editing_prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                editing_prompt_embeds = editing_prompt_embeds[0]
            else:
                editing_prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                editing_prompt_embeds = editing_prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                editing_prompt_embeds = self.text_encoder.text_model.final_layer_norm(editing_prompt_embeds)

        editing_prompt_embeds = editing_prompt_embeds.to(dtype=negative_prompt_embeds.dtype, device=device)

        # duplicate the editing embeddings for each generation per prompt
        bs_embed_edit, seq_len, _ = editing_prompt_embeds.shape
        editing_prompt_embeds = editing_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        editing_prompt_embeds = editing_prompt_embeds.view(bs_embed_edit * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance

    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
    seq_len = negative_prompt_embeds.shape[1]

    negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder, lora_scale)

    return editing_prompt_embeds, negative_prompt_embeds, num_edit_tokens
Optional[str] = "pil", return_dict: bool = True, editing_prompt: Optional[Union[str, List[str]]] = None, editing_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, edit_guidance_scale: Optional[Union[float, List[float]]] = 5, edit_warmup_steps: Optional[Union[int, List[int]]] = 0, edit_cooldown_steps: Optional[Union[int, List[int]]] = None, edit_threshold: Optional[Union[float, List[float]]] = 0.9, user_mask: Optional[torch.Tensor] = None, sem_guidance: Optional[List[torch.Tensor]] = None, use_cross_attn_mask: bool = False, use_intersect_mask: bool = True, attn_store_steps: Optional[List[int]] = [], store_averaged_over_steps: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): r""" The call function to the pipeline for editing. The [`~pipelines.ledits_pp.LEditsPPPipelineStableDiffusion.invert`] method has to be called beforehand. Edits will always be performed for the last inverted image(s). Args: negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] instead of a plain tuple. 
editing_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. The image is reconstructed by setting `editing_prompt = None`. Guidance direction of prompt should be specified via `reverse_editing_direction`. editing_prompt_embeds (`torch.Tensor>`, *optional*): Pre-computed embeddings to use for guiding the image generation. Guidance direction of embedding should be specified via `reverse_editing_direction`. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): Whether the corresponding prompt in `editing_prompt` should be increased or decreased. edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): Guidance scale for guiding the image generation. If provided as list values should correspond to `editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 12 of [LEDITS++ Paper](https://arxiv.org/abs/2301.12247). edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10): Number of diffusion steps (for each prompt) for which guidance will not be applied. edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`): Number of diffusion steps (for each prompt) after which guidance will no longer be applied. edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): Masking threshold of guidance. Threshold should be proportional to the image region that is modified. 'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++ Paper](https://arxiv.org/abs/2301.12247). user_mask (`torch.Tensor`, *optional*): User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s implicit masks do not meet user preferences. 
sem_guidance (`List[torch.Tensor]`, *optional*): List of pre-generated guidance vectors to be applied at generation. Length of the list has to correspond to `num_inference_steps`. use_cross_attn_mask (`bool`, defaults to `False`): Whether cross-attention masks are used. Cross-attention masks are always used when use_intersect_mask is set to true. Cross-attention masks are defined as 'M^1' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf). use_intersect_mask (`bool`, defaults to `True`): Whether the masking term is calculated as intersection of cross-attention masks and masks derived from the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise estimate are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf). attn_store_steps (`List[int]`, *optional*): Steps for which the attention maps are stored in the AttentionStore. Just for visualization purposes. store_averaged_over_steps (`bool`, defaults to `True`): Whether the attention maps for the 'attn_store_steps' are stored averaged over the diffusion steps. If False, attention maps for each step are stores separately. Just for visualization purposes. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] or `tuple`: [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ if self.inversion_steps is None: raise ValueError( "You need to invert an input image first before calling the pipeline. The `invert` method has to be called beforehand. Edits will always be performed for the last inverted image(s)." ) eta = self.eta num_images_per_prompt = 1 latents = self.init_latents zs = self.zs self.scheduler.set_timesteps(len(self.scheduler.timesteps)) if use_intersect_mask: use_cross_attn_mask = True if use_cross_attn_mask: self.smoothing = LeditsGaussianSmoothing(self.device) if user_mask is not None: user_mask = user_mask.to(self.device) org_prompt = "" # 1. Check inputs. 
Raise error if not correct self.check_inputs( negative_prompt, editing_prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_rescale = guidance_rescale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs # 2. Define call parameters batch_size = self.batch_size if editing_prompt: enable_edit_guidance = True if isinstance(editing_prompt, str): editing_prompt = [editing_prompt] self.enabled_editing_prompts = len(editing_prompt) elif editing_prompt_embeds is not None: enable_edit_guidance = True self.enabled_editing_prompts = editing_prompt_embeds.shape[0] else: self.enabled_editing_prompts = 0 enable_edit_guidance = False # 3. Encode input prompt lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) edit_concepts, uncond_embeddings, num_edit_tokens = self.encode_prompt( editing_prompt=editing_prompt, device=self.device, num_images_per_prompt=num_images_per_prompt, enable_edit_guidance=enable_edit_guidance, negative_prompt=negative_prompt, editing_prompt_embeds=editing_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=lora_scale, clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if enable_edit_guidance: text_embeddings = torch.cat([uncond_embeddings, edit_concepts]) self.text_cross_attention_maps = [editing_prompt] if isinstance(editing_prompt, str) else editing_prompt else: text_embeddings = torch.cat([uncond_embeddings]) # 4. 
Prepare timesteps # self.scheduler.set_timesteps(num_inference_steps, device=self.device) timesteps = self.inversion_steps t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0] :])} if use_cross_attn_mask: self.attention_store = LeditsAttentionStore( average=store_averaged_over_steps, batch_size=batch_size, max_size=(latents.shape[-2] / 4.0) * (latents.shape[-1] / 4.0), max_resolution=None, ) self.prepare_unet(self.attention_store, PnP=False) resolution = latents.shape[-2:] att_res = (int(resolution[0] / 4), int(resolution[1] / 4)) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, None, None, text_embeddings.dtype, self.device, latents, ) # 6. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(eta) self.sem_guidance = None self.activation_mask = None # 7. Denoising loop num_warmup_steps = 0 with self.progress_bar(total=len(timesteps)) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance if enable_edit_guidance: latent_model_input = torch.cat([latents] * (1 + self.enabled_editing_prompts)) else: latent_model_input = latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) text_embed_input = text_embeddings # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input).sample noise_pred_out = noise_pred.chunk(1 + self.enabled_editing_prompts) # [b,4, 64, 64] noise_pred_uncond = noise_pred_out[0] noise_pred_edit_concepts = noise_pred_out[1:] noise_guidance_edit = torch.zeros( noise_pred_uncond.shape, device=self.device, dtype=noise_pred_uncond.dtype, ) if sem_guidance is not None and len(sem_guidance) > i: noise_guidance_edit += sem_guidance[i].to(self.device) elif enable_edit_guidance: if self.activation_mask is None: self.activation_mask = torch.zeros( (len(timesteps), 
len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape) ) if self.sem_guidance is None: self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_uncond.shape)) for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): if isinstance(edit_warmup_steps, list): edit_warmup_steps_c = edit_warmup_steps[c] else: edit_warmup_steps_c = edit_warmup_steps if i < edit_warmup_steps_c: continue if isinstance(edit_guidance_scale, list): edit_guidance_scale_c = edit_guidance_scale[c] else: edit_guidance_scale_c = edit_guidance_scale if isinstance(edit_threshold, list): edit_threshold_c = edit_threshold[c] else: edit_threshold_c = edit_threshold if isinstance(reverse_editing_direction, list): reverse_editing_direction_c = reverse_editing_direction[c] else: reverse_editing_direction_c = reverse_editing_direction if isinstance(edit_cooldown_steps, list): edit_cooldown_steps_c = edit_cooldown_steps[c] elif edit_cooldown_steps is None: edit_cooldown_steps_c = i + 1 else: edit_cooldown_steps_c = edit_cooldown_steps if i >= edit_cooldown_steps_c: continue noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond if reverse_editing_direction_c: noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c if user_mask is not None: noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask if use_cross_attn_mask: out = self.attention_store.aggregate_attention( attention_maps=self.attention_store.step_store, prompts=self.text_cross_attention_maps, res=att_res, from_where=["up", "down"], is_cross=True, select=self.text_cross_attention_maps.index(editing_prompt[c]), ) attn_map = out[:, :, :, 1 : 1 + num_edit_tokens[c]] # 0 -> startoftext # average over all tokens if attn_map.shape[3] != num_edit_tokens[c]: raise ValueError( f"Incorrect shape of attention_map. Expected size {num_edit_tokens[c]}, but found {attn_map.shape[3]}!" 
) attn_map = torch.sum(attn_map, dim=3) # gaussian_smoothing attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect") attn_map = self.smoothing(attn_map).squeeze(1) # torch.quantile function expects float32 if attn_map.dtype == torch.float32: tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1) else: tmp = torch.quantile( attn_map.flatten(start_dim=1).to(torch.float32), edit_threshold_c, dim=1 ).to(attn_map.dtype) attn_mask = torch.where( attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1, *att_res), 1.0, 0.0 ) # resolution must match latent space dimension attn_mask = F.interpolate( attn_mask.unsqueeze(1), noise_guidance_edit_tmp.shape[-2:], # 64,64 ).repeat(1, 4, 1, 1) self.activation_mask[i, c] = attn_mask.detach().cpu() if not use_intersect_mask: noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask if use_intersect_mask: if t <= 800: noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) noise_guidance_edit_tmp_quantile = torch.sum( noise_guidance_edit_tmp_quantile, dim=1, keepdim=True ) noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat( 1, self.unet.config.in_channels, 1, 1 ) # torch.quantile function expects float32 if noise_guidance_edit_tmp_quantile.dtype == torch.float32: tmp = torch.quantile( noise_guidance_edit_tmp_quantile.flatten(start_dim=2), edit_threshold_c, dim=2, keepdim=False, ) else: tmp = torch.quantile( noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), edit_threshold_c, dim=2, keepdim=False, ).to(noise_guidance_edit_tmp_quantile.dtype) intersect_mask = ( torch.where( noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], torch.ones_like(noise_guidance_edit_tmp), torch.zeros_like(noise_guidance_edit_tmp), ) * attn_mask ) self.activation_mask[i, c] = intersect_mask.detach().cpu() noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask else: # print(f"only attention mask for step {i}") noise_guidance_edit_tmp = 
noise_guidance_edit_tmp * attn_mask elif not use_cross_attn_mask: # calculate quantile noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) noise_guidance_edit_tmp_quantile = torch.sum( noise_guidance_edit_tmp_quantile, dim=1, keepdim=True ) noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1) # torch.quantile function expects float32 if noise_guidance_edit_tmp_quantile.dtype == torch.float32: tmp = torch.quantile( noise_guidance_edit_tmp_quantile.flatten(start_dim=2), edit_threshold_c, dim=2, keepdim=False, ) else: tmp = torch.quantile( noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), edit_threshold_c, dim=2, keepdim=False, ).to(noise_guidance_edit_tmp_quantile.dtype) self.activation_mask[i, c] = ( torch.where( noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], torch.ones_like(noise_guidance_edit_tmp), torch.zeros_like(noise_guidance_edit_tmp), ) .detach() .cpu() ) noise_guidance_edit_tmp = torch.where( noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], noise_guidance_edit_tmp, torch.zeros_like(noise_guidance_edit_tmp), ) noise_guidance_edit += noise_guidance_edit_tmp self.sem_guidance[i] = noise_guidance_edit.detach().cpu() noise_pred = noise_pred_uncond + noise_guidance_edit if enable_edit_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg( noise_pred, noise_pred_edit_concepts.mean(dim=0, keepdim=False), guidance_rescale=self.guidance_rescale, ) idx = t_to_idx[int(t)] latents = self.scheduler.step( noise_pred, t, latents, variance_noise=zs[idx], **extra_step_kwargs ).prev_sample # step callback if use_cross_attn_mask: store_step = i in attn_store_steps self.attention_store.between_steps(store_step) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) # prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() # 8. 
@torch.no_grad()
def invert(
    self,
    image: PipelineImageInput,
    source_prompt: str = "",
    source_guidance_scale: float = 3.5,
    num_inversion_steps: int = 30,
    skip: float = 0.15,
    generator: Optional[torch.Generator] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    resize_mode: Optional[str] = "default",
    crops_coords: Optional[Tuple[int, int, int, int]] = None,
):
    r"""
    The function to the pipeline for image inversion as described by the [LEDITS++
    Paper](https://arxiv.org/abs/2311.16711). If the scheduler is set to [`~schedulers.DDIMScheduler`] the
    inversion proposed by [edit-friendly DDPM](https://arxiv.org/abs/2304.06140) will be performed instead.

    The method stores its results on the pipeline (`self.inversion_steps`, `self.batch_size`, `self.eta`,
    `self.init_latents` and `self.zs`) so that a subsequent `__call__` can edit the inverted image(s).

    Args:
        image (`PipelineImageInput`):
            Input for the image(s) that are to be edited. Multiple input images have to default to the same aspect
            ratio.
        source_prompt (`str`, defaults to `""`):
            Prompt describing the input image that will be used for guidance during inversion. Guidance is
            disabled if the `source_prompt` is `""`.
        source_guidance_scale (`float`, defaults to `3.5`):
            Strength of guidance during inversion. Text embeddings for `source_prompt` are only requested from
            `encode_prompt` when this value is greater than `1.0`.
        num_inversion_steps (`int`, defaults to `30`):
            Number of total performed inversion steps after discarding the initial `skip` steps.
        skip (`float`, defaults to `0.15`):
            Portion of initial steps that will be ignored for inversion and subsequent generation. Lower values
            will lead to stronger changes to the input image. `skip` has to be between `0` and `1`.
        generator (`torch.Generator`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make inversion
            deterministic.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Only its `"scale"` entry is read here (as the LoRA scale for prompt encoding).
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        height (`int`, *optional*, defaults to `None`):
            The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
            height.
        width (`int`, *optional*, defaults to `None`):
            The width in preprocessed. If `None`, will use `get_default_height_width()` to get the default width.
        resize_mode (`str`, *optional*, defaults to `default`):
            The resize mode, can be one of `default`, `fill` or `crop`. If `default`, will resize the image to fit
            within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
            will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
            then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
            the image to fit within the specified width and height, maintaining the aspect ratio, and then center
            the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
            supported for PIL image input.
        crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
            The crop coordinates for each image in the batch. If `None`, will not crop the image.

    Returns:
        [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
        and respective VAE reconstruction(s).
    """
    # Reset attn processor, we do not want to store attn maps during inversion
    self.unet.set_attn_processor(AttnProcessor())

    self.eta = 1.0

    self.scheduler.config.timestep_spacing = "leading"
    self.scheduler.set_timesteps(int(num_inversion_steps * (1 + skip)))
    # Only the last `num_inversion_steps` timesteps are inverted; the first `skip` portion is discarded.
    self.inversion_steps = self.scheduler.timesteps[-num_inversion_steps:]
    timesteps = self.inversion_steps

    # 1. encode image
    x0, resized = self.encode_image(
        image,
        dtype=self.text_encoder.dtype,
        height=height,
        width=width,
        resize_mode=resize_mode,
        crops_coords=crops_coords,
    )
    self.batch_size = x0.shape[0]

    # autoencoder reconstruction
    image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
    image_rec = self.image_processor.postprocess(image_rec, output_type="pil")

    # 2. get embeddings
    do_classifier_free_guidance = source_guidance_scale > 1.0

    lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None

    # `encode_prompt` returns (editing_prompt_embeds, negative_prompt_embeds, num_edit_tokens): the embedding of
    # `source_prompt` comes first and the unconditional embedding second. Unpacking them the other way round
    # would reverse the guidance direction used during inversion.
    text_embeddings, uncond_embedding, _ = self.encode_prompt(
        num_images_per_prompt=1,
        device=self.device,
        negative_prompt=None,
        enable_edit_guidance=do_classifier_free_guidance,
        editing_prompt=source_prompt,
        lora_scale=lora_scale,
        clip_skip=clip_skip,
    )

    # Guidance needs both a non-empty prompt and text embeddings to guide with.
    use_source_guidance = source_prompt != "" and text_embeddings is not None

    # 3. find zs and xts
    variance_noise_shape = (num_inversion_steps, *x0.shape)

    # intermediate latents
    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    xts = torch.zeros(size=variance_noise_shape, device=self.device, dtype=uncond_embedding.dtype)

    for t in reversed(timesteps):
        idx = num_inversion_steps - t_to_idx[int(t)] - 1
        # Each intermediate latent is drawn independently by noising x0 directly to timestep t.
        noise = randn_tensor(shape=x0.shape, generator=generator, device=self.device, dtype=x0.dtype)
        xts[idx] = self.scheduler.add_noise(x0, noise, torch.Tensor([t]))
    xts = torch.cat([x0.unsqueeze(0), xts], dim=0)

    self.scheduler.set_timesteps(len(self.scheduler.timesteps))
    # noise maps
    zs = torch.zeros(size=variance_noise_shape, device=self.device, dtype=uncond_embedding.dtype)

    with self.progress_bar(total=len(timesteps)) as progress_bar:
        for t in timesteps:
            idx = num_inversion_steps - t_to_idx[int(t)] - 1
            # 1. predict noise residual
            xt = xts[idx + 1]

            noise_pred = self.unet(xt, timestep=t, encoder_hidden_states=uncond_embedding).sample

            if use_source_guidance:
                noise_pred_cond = self.unet(xt, timestep=t, encoder_hidden_states=text_embeddings).sample
                noise_pred = noise_pred + source_guidance_scale * (noise_pred_cond - noise_pred)

            xtm1 = xts[idx]
            z, xtm1_corrected = compute_noise(self.scheduler, xtm1, xt, t, noise_pred, self.eta)
            zs[idx] = z

            # correction to avoid error accumulation
            xts[idx] = xtm1_corrected

            progress_bar.update()

    self.init_latents = xts[-1].expand(self.batch_size, -1, -1, -1)
    # Flip so the noise maps are ordered the way `__call__` indexes them.
    zs = zs.flip(0)
    self.zs = zs
    return LEditsPPInversionPipelineOutput(images=resized, vae_reconstruction_images=image_rec)
def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, eta):
    """
    Recover the noise that maps `latents` (x_t) onto `prev_latents` (x_{t-1}) under a DDIM step.

    The deterministic part of the DDIM update (`mu_xt`) is computed from `latents` and `noise_pred`; the
    returned noise is the residual `(prev_latents - mu_xt) / std_dev_t`, where `std_dev_t = eta * sqrt(variance)`.

    Args:
        scheduler: Object providing `config.num_train_timesteps`, `config.clip_sample`, `num_inference_steps`,
            `alphas_cumprod`, `final_alpha_cumprod` and `_get_variance(timestep, prev_timestep)`.
        prev_latents: Target latents at the previous (less noisy) step.
        latents: Latents at `timestep`.
        timestep: Current timestep, used to index `scheduler.alphas_cumprod`.
        noise_pred: Model noise prediction for `latents`.
        eta: Scale of the stochastic term of the DDIM update.

    Returns:
        A tuple `(noise, prev_sample)`, where `prev_sample = mu_xt + std_dev_t * noise`. When `std_dev_t` is not
        positive (zero variance or `eta == 0`) the step is deterministic: `noise` is a single-element zero tensor
        and `prev_sample` equals `mu_xt`.
    """
    # 1. get previous step value (=t-1)
    prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps

    # 2. compute alphas, betas
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = (
        scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
    )

    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)

    # 4. Clip "predicted x_0"
    if scheduler.config.clip_sample:
        pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = scheduler._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred

    # modified so that updated xtm1 is returned as well (to avoid error accumulation)
    mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    # Guard on the actual divisor: checking only `variance` would divide by zero when eta == 0.
    if std_dev_t > 0.0:
        noise = (prev_latents - mu_xt) / std_dev_t
    else:
        noise = torch.tensor([0.0]).to(latents.device)

    return noise, mu_xt + std_dev_t * noise
def compute_noise(scheduler, *args):
    """
    Dispatch noise recovery to the implementation matching `scheduler`.

    Args:
        scheduler: A `DDIMScheduler`, or a `DPMSolverMultistepScheduler` whose config has
            `algorithm_type == "sde-dpmsolver++"` and `solver_order == 2`.
        *args: Forwarded unchanged to the selected implementation
            (`prev_latents, latents, timestep, noise_pred, eta`).

    Returns:
        The `(noise, prev_sample)` tuple produced by the selected implementation.

    Raises:
        NotImplementedError: If `scheduler` is none of the supported configurations.
    """
    if isinstance(scheduler, DDIMScheduler):
        return compute_noise_ddim(scheduler, *args)

    if (
        isinstance(scheduler, DPMSolverMultistepScheduler)
        and scheduler.config.algorithm_type == "sde-dpmsolver++"
        and scheduler.config.solver_order == 2
    ):
        return compute_noise_sde_dpm_pp_2nd(scheduler, *args)

    # Same exception type as before, now with a message that tells the caller what is supported.
    raise NotImplementedError(
        "compute_noise only supports DDIMScheduler, or DPMSolverMultistepScheduler configured with "
        f"algorithm_type='sde-dpmsolver++' and solver_order=2; got {type(scheduler).__name__}."
    )
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( Attention, AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler from ...utils import ( USE_PEFT_BACKEND, is_invisible_watermark_available, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput if is_invisible_watermark_available(): from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = 
""" Examples: ```py >>> import torch >>> import PIL >>> import requests >>> from io import BytesIO >>> from diffusers import LEditsPPPipelineStableDiffusionXL >>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained( ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> def download_image(url): ... response = requests.get(url) ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg" >>> image = download_image(img_url) >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2) >>> edited_image = pipe( ... editing_prompt=["tennis ball", "tomato"], ... reverse_editing_direction=[True, False], ... edit_guidance_scale=[5.0, 10.0], ... edit_threshold=[0.9, 0.85], ... ).images[0] ``` """ # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LeditsAttentionStore class LeditsAttentionStore: @staticmethod def get_empty_store(): return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []} def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False): # attn.shape = batch_size * head_size, seq_len query, seq_len_key if attn.shape[1] <= self.max_size: bs = 1 + int(PnP) + editing_prompts skip = 2 if PnP else 1 # skip PnP & unconditional attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3) source_batch_size = int(attn.shape[1] // bs) self.forward(attn[:, skip * source_batch_size :], is_cross, place_in_unet) def forward(self, attn, is_cross: bool, place_in_unet: str): key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" self.step_store[key].append(attn) def between_steps(self, store_step=True): if store_step: if self.average: if len(self.attention_store) == 0: self.attention_store = self.step_store else: for key in self.attention_store: for i in range(len(self.attention_store[key])): 
class LeditsGaussianSmoothing:
    """
    Fixed 3x3 Gaussian blur applied with a single-channel 2D convolution.

    The kernel is built once at construction time (size 3x3, `sigma = 0.5` per axis), normalised so that its entries
    sum to 1, and stored in `self.weight` with shape `(1, 1, 3, 3)` on `device`.

    Args:
        device (`torch.device` or `str`): Device the convolution weight is moved to.
    """

    def __init__(self, device):
        kernel_size = [3, 3]
        sigma = [0.5, 0.5]

        # The gaussian kernel is the product of the gaussian function of each dimension.
        kernel = 1
        # "ij" is the historical default of `torch.meshgrid`; passing it explicitly keeps the result identical while
        # avoiding the deprecation warning emitted when `indexing` is omitted.
        meshgrids = torch.meshgrid(
            [torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij"
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            # NOTE: the exponent divides by (2 * std), i.e. exp(-(x - mean)^2 / (4 * std^2)). This is kept exactly as
            # is so the produced weights (and therefore the masks smoothed with them) do not change.
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to a conv2d weight of shape (out_channels=1, in_channels=1, kH, kW).
        kernel = kernel.view(1, 1, *kernel.size())

        self.weight = kernel.to(device)

    def __call__(self, input):
        """
        Apply the gaussian filter to `input`.

        Args:
            input (`torch.Tensor`):
                Input to apply the gaussian filter on. It is passed straight to `F.conv2d` with a `(1, 1, 3, 3)`
                weight, so it must have a single channel. No padding is applied here, so each spatial dimension of
                the output is 2 smaller than the input; callers that need the original size must pad beforehand.

        Returns:
            `torch.Tensor`: The filtered output, in the dtype of `input`.
        """
        return F.conv2d(input, weight=self.weight.to(input.dtype))
editing_prompts=self.editing_prompts, PnP=self.pnp, ) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class LEditsPPPipelineStableDiffusionXL( DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, IPAdapterMixin, ): """ Pipeline for textual image editing using LEDits++ with Stable Diffusion XL. This model inherits from [`DiffusionPipeline`] and builds on the [`StableDiffusionXLPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). In addition the pipeline inherits the following loading methods: - *LoRA*: [`LEditsPPPipelineStableDiffusionXL.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder. Stable Diffusion XL uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. 
 tokenizer ([`~transformers.CLIPTokenizer`]): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 ([`~transformers.CLIPTokenizer`]): Second tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]. If any other scheduler is passed it will automatically be set to [`DPMSolverMultistepScheduler`]. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`): Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1.0`. add_watermarker (`bool`, *optional*): Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to watermark output images. If not defined, it will default to `True` if the package is installed, otherwise no watermarker will be used.
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" _optional_components = [ "tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2", "image_encoder", "feature_extractor", ] _callback_tensor_inputs = [ "latents", "prompt_embeds", "negative_prompt_embeds", "add_text_embeds", "add_time_ids", "negative_pooled_prompt_embeds", "negative_add_time_ids", ] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, text_encoder_2: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[DPMSolverMultistepScheduler, DDIMScheduler], image_encoder: CLIPVisionModelWithProjection = None, feature_extractor: CLIPImageProcessor = None, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, unet=unet, scheduler=scheduler, image_encoder=image_encoder, feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if not isinstance(scheduler, DDIMScheduler) and not isinstance(scheduler, DPMSolverMultistepScheduler): self.scheduler = DPMSolverMultistepScheduler.from_config( scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2 ) logger.warning( "This pipeline only supports DDIMScheduler and DPMSolverMultistepScheduler. " "The scheduler has been changed to DPMSolverMultistepScheduler." 
) self.default_sample_size = self.unet.config.sample_size add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() if add_watermarker: self.watermark = StableDiffusionXLWatermarker() else: self.watermark = None self.inversion_steps = None def encode_prompt( self, device: Optional[torch.device] = None, num_images_per_prompt: int = 1, negative_prompt: Optional[str] = None, negative_prompt_2: Optional[str] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, enable_edit_guidance: bool = True, editing_prompt: Optional[str] = None, editing_prompt_embeds: Optional[torch.Tensor] = None, editing_pooled_prompt_embeds: Optional[torch.Tensor] = None, ) -> object: r""" Encodes the prompt into text encoder hidden states. Args: device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. 
lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. enable_edit_guidance (`bool`): Whether to guide towards an editing prompt or not. editing_prompt (`str` or `List[str]`, *optional*): Editing prompt(s) to be encoded. If not defined and 'enable_edit_guidance' is True, one has to pass `editing_prompt_embeds` instead. editing_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided and 'enable_edit_guidance' is True, editing_prompt_embeds will be generated from `editing_prompt` input argument. editing_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated edit pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled editing_pooled_prompt_embeds will be generated from `editing_prompt` input argument. 
""" device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) batch_size = self.batch_size # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) num_edit_tokens = 0 # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but image inversion " f" has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of the input images." 
) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(negative_prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(negative_pooled_prompt_embeds) if enable_edit_guidance and editing_prompt_embeds is None: editing_prompt_2 = editing_prompt editing_prompts = [editing_prompt, editing_prompt_2] edit_prompt_embeds_list = [] for editing_prompt, tokenizer, text_encoder in zip(editing_prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): editing_prompt = self.maybe_convert_prompt(editing_prompt, tokenizer) max_length = negative_prompt_embeds.shape[1] edit_concepts_input = tokenizer( # [x for item in editing_prompt for x in repeat(item, batch_size)], editing_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", return_length=True, ) num_edit_tokens = edit_concepts_input.length - 2 edit_concepts_embeds = text_encoder( edit_concepts_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder editing_pooled_prompt_embeds = 
edit_concepts_embeds[0] if clip_skip is None: edit_concepts_embeds = edit_concepts_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)] edit_prompt_embeds_list.append(edit_concepts_embeds) edit_concepts_embeds = torch.concat(edit_prompt_embeds_list, dim=-1) elif not enable_edit_guidance: edit_concepts_embeds = None editing_pooled_prompt_embeds = None negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) bs_embed, seq_len, _ = negative_prompt_embeds.shape # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if enable_edit_guidance: bs_embed_edit, seq_len, _ = edit_concepts_embeds.shape edit_concepts_embeds = edit_concepts_embeds.to(dtype=self.text_encoder_2.dtype, device=device) edit_concepts_embeds = edit_concepts_embeds.repeat(1, num_images_per_prompt, 1) edit_concepts_embeds = edit_concepts_embeds.view(bs_embed_edit * num_images_per_prompt, seq_len, -1) negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed * num_images_per_prompt, -1 ) if enable_edit_guidance: editing_pooled_prompt_embeds = editing_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed_edit * num_images_per_prompt, -1 ) if self.text_encoder is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if isinstance(self, 
StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) return ( negative_prompt_embeds, edit_concepts_embeds, negative_pooled_prompt_embeds, editing_pooled_prompt_embeds, num_edit_tokens, ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, eta, generator=None): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, negative_prompt=None, negative_prompt_2=None, negative_prompt_embeds=None, negative_pooled_prompt_embeds=None, ): if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) elif negative_prompt_2 is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, device, latents): latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features if expected_add_embed_dim != passed_add_embed_dim: raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
) add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae def upcast_vae(self): dtype = self.vae.dtype self.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( self.vae.decoder.mid_block.attentions[0].processor, ( AttnProcessor2_0, XFormersAttnProcessor, ), ) # if xformers or torch_2_0 is used attention block does not need # to be in float32 which can save lots of memory if use_torch_2_0_or_xformers: self.vae.post_quant_conv.to(dtype) self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding( self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: w (`torch.Tensor`): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type of the generated embeddings. Returns: `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. 
""" assert len(w.shape) == 1 w = w * 1000.0 half_dim = embedding_dim // 2 emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) emb = w.to(dtype)[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb @property def guidance_scale(self): return self._guidance_scale @property def guidance_rescale(self): return self._guidance_rescale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None @property def cross_attention_kwargs(self): return self._cross_attention_kwargs @property def denoising_end(self): return self._denoising_end @property def num_timesteps(self): return self._num_timesteps # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet def prepare_unet(self, attention_store, PnP: bool = False): attn_procs = {} for name in self.unet.attn_processors.keys(): if name.startswith("mid_block"): place_in_unet = "mid" elif name.startswith("up_blocks"): place_in_unet = "up" elif name.startswith("down_blocks"): place_in_unet = "down" else: continue if "attn2" in name and place_in_unet != "mid": attn_procs[name] = LEDITSCrossAttnProcessor( attention_store=attention_store, place_in_unet=place_in_unet, pnp=PnP, editing_prompts=self.enabled_editing_prompts, ) else: attn_procs[name] = AttnProcessor() self.unet.set_attn_processor(attn_procs) @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, denoising_end: 
Optional[float] = None, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Optional[Tuple[int, int]] = None, editing_prompt: Optional[Union[str, List[str]]] = None, editing_prompt_embeddings: Optional[torch.Tensor] = None, editing_pooled_prompt_embeds: Optional[torch.Tensor] = None, reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, edit_guidance_scale: Optional[Union[float, List[float]]] = 5, edit_warmup_steps: Optional[Union[int, List[int]]] = 0, edit_cooldown_steps: Optional[Union[int, List[int]]] = None, edit_threshold: Optional[Union[float, List[float]]] = 0.9, sem_guidance: Optional[List[torch.Tensor]] = None, use_cross_attn_mask: bool = False, use_intersect_mask: bool = False, user_mask: Optional[torch.Tensor] = None, attn_store_steps: Optional[List[int]] = [], store_averaged_over_steps: bool = True, clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): r""" The call function to the pipeline for editing. The [`~pipelines.ledits_pp.LEditsPPPipelineStableDiffusionXL.invert`] method has to be called beforehand. Edits will always be performed for the last inverted image(s). Args: denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. 
As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. 
The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). guidance_rescale (`float`, *optional*, defaults to 0.7): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). editing_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. 
The image is reconstructed by setting `editing_prompt = None`. Guidance direction of prompt should be specified via `reverse_editing_direction`. editing_prompt_embeddings (`torch.Tensor`, *optional*): Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, editing_prompt_embeddings will be generated from `editing_prompt` input argument. editing_pooled_prompt_embeddings (`torch.Tensor`, *optional*): Pre-generated pooled edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, editing_prompt_embeddings will be generated from `editing_prompt` input argument. reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): Whether the corresponding prompt in `editing_prompt` should be increased or decreased. edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): Guidance scale for guiding the image generation. If provided as list values should correspond to `editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 12 of [LEDITS++ Paper](https://arxiv.org/abs/2301.12247). edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10): Number of diffusion steps (for each prompt) for which guidance is not applied. edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`): Number of diffusion steps (for each prompt) after which guidance is no longer applied. edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): Masking threshold of guidance. Threshold should be proportional to the image region that is modified. 'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++ Paper](https://arxiv.org/abs/2301.12247). sem_guidance (`List[torch.Tensor]`, *optional*): List of pre-generated guidance vectors to be applied at generation. Length of the list has to correspond to `num_inference_steps`. use_cross_attn_mask: Whether cross-attention masks are used. 
Cross-attention masks are always used when use_intersect_mask is set to true. Cross-attention masks are defined as 'M^1' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf). use_intersect_mask: Whether the masking term is calculated as the intersection of cross-attention masks and masks derived from the noise estimate. Cross-attention masks are defined as 'M^1' and masks derived from the noise estimate are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf). user_mask: User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s implicit masks do not meet user preferences. attn_store_steps: Steps for which the attention maps are stored in the AttentionStore. Just for visualization purposes. store_averaged_over_steps: Whether the attention maps for the 'attn_store_steps' are stored averaged over the diffusion steps. If False, attention maps for each step are stored separately. Just for visualization purposes. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, *optional*): A function that is called at the end of each denoising step during inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. 
Examples: Returns: [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] or `tuple`: [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ if self.inversion_steps is None: raise ValueError( "You need to invert an input image first before calling the pipeline. The `invert` method has to be called beforehand. Edits will always be performed for the last inverted image(s)." ) eta = self.eta num_images_per_prompt = 1 latents = self.init_latents zs = self.zs self.scheduler.set_timesteps(len(self.scheduler.timesteps)) if use_intersect_mask: use_cross_attn_mask = True if use_cross_attn_mask: self.smoothing = LeditsGaussianSmoothing(self.device) if user_mask is not None: user_mask = user_mask.to(self.device) # TODO: Check inputs # 1. Check inputs. Raise error if not correct # self.check_inputs( # callback_steps, # negative_prompt, # negative_prompt_2, # prompt_embeds, # negative_prompt_embeds, # pooled_prompt_embeds, # negative_pooled_prompt_embeds, # ) self._guidance_rescale = guidance_rescale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end # 2. Define call parameters batch_size = self.batch_size device = self._execution_device if editing_prompt: enable_edit_guidance = True if isinstance(editing_prompt, str): editing_prompt = [editing_prompt] self.enabled_editing_prompts = len(editing_prompt) elif editing_prompt_embeddings is not None: enable_edit_guidance = True self.enabled_editing_prompts = editing_prompt_embeddings.shape[0] else: self.enabled_editing_prompts = 0 enable_edit_guidance = False # 3. 
Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) ( prompt_embeds, edit_prompt_embeds, negative_pooled_prompt_embeds, pooled_edit_embeds, num_edit_tokens, ) = self.encode_prompt( device=device, num_images_per_prompt=num_images_per_prompt, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, negative_prompt_embeds=negative_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, enable_edit_guidance=enable_edit_guidance, editing_prompt=editing_prompt, editing_prompt_embeds=editing_prompt_embeddings, editing_pooled_prompt_embeds=editing_pooled_prompt_embeds, ) # 4. Prepare timesteps # self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.inversion_steps t_to_idx = {int(v): k for k, v in enumerate(timesteps)} if use_cross_attn_mask: self.attention_store = LeditsAttentionStore( average=store_averaged_over_steps, batch_size=batch_size, max_size=(latents.shape[-2] / 4.0) * (latents.shape[-1] / 4.0), max_resolution=None, ) self.prepare_unet(self.attention_store) resolution = latents.shape[-2:] att_res = (int(resolution[0] / 4), int(resolution[1] / 4)) # 5. Prepare latent variables latents = self.prepare_latents(device=device, latents=latents) # 6. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(eta) if self.text_encoder_2 is None: text_encoder_projection_dim = int(negative_pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim # 7. 
Prepare added time ids & embeddings add_text_embeds = negative_pooled_prompt_embeds add_time_ids = self._get_add_time_ids( self.size, crops_coords_top_left, self.size, dtype=negative_pooled_prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if enable_edit_guidance: prompt_embeds = torch.cat([prompt_embeds, edit_prompt_embeds], dim=0) add_text_embeds = torch.cat([add_text_embeds, pooled_edit_embeds], dim=0) edit_concepts_time_ids = add_time_ids.repeat(edit_prompt_embeds.shape[0], 1) add_time_ids = torch.cat([add_time_ids, edit_concepts_time_ids], dim=0) self.text_cross_attention_maps = [editing_prompt] if isinstance(editing_prompt, str) else editing_prompt prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) if ip_adapter_image is not None: # TODO: fix image encoding image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) # 8. Denoising loop self.sem_guidance = None self.activation_mask = None if ( self.denoising_end is not None and isinstance(self.denoising_end, float) and self.denoising_end > 0 and self.denoising_end < 1 ): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # 9. 
Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) self._num_timesteps = len(timesteps) with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * (1 + self.enabled_editing_prompts)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] noise_pred_out = noise_pred.chunk(1 + self.enabled_editing_prompts) # [b,4, 64, 64] noise_pred_uncond = noise_pred_out[0] noise_pred_edit_concepts = noise_pred_out[1:] noise_guidance_edit = torch.zeros( noise_pred_uncond.shape, device=self.device, dtype=noise_pred_uncond.dtype, ) if sem_guidance is not None and len(sem_guidance) > i: noise_guidance_edit += sem_guidance[i].to(self.device) elif enable_edit_guidance: if self.activation_mask is None: self.activation_mask = torch.zeros( (len(timesteps), self.enabled_editing_prompts, *noise_pred_edit_concepts[0].shape) ) if self.sem_guidance is None: self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_uncond.shape)) # noise_guidance_edit = torch.zeros_like(noise_guidance) for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): if isinstance(edit_warmup_steps, list): edit_warmup_steps_c = 
edit_warmup_steps[c] else: edit_warmup_steps_c = edit_warmup_steps if i < edit_warmup_steps_c: continue if isinstance(edit_guidance_scale, list): edit_guidance_scale_c = edit_guidance_scale[c] else: edit_guidance_scale_c = edit_guidance_scale if isinstance(edit_threshold, list): edit_threshold_c = edit_threshold[c] else: edit_threshold_c = edit_threshold if isinstance(reverse_editing_direction, list): reverse_editing_direction_c = reverse_editing_direction[c] else: reverse_editing_direction_c = reverse_editing_direction if isinstance(edit_cooldown_steps, list): edit_cooldown_steps_c = edit_cooldown_steps[c] elif edit_cooldown_steps is None: edit_cooldown_steps_c = i + 1 else: edit_cooldown_steps_c = edit_cooldown_steps if i >= edit_cooldown_steps_c: continue noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond if reverse_editing_direction_c: noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c if user_mask is not None: noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask if use_cross_attn_mask: out = self.attention_store.aggregate_attention( attention_maps=self.attention_store.step_store, prompts=self.text_cross_attention_maps, res=att_res, from_where=["up", "down"], is_cross=True, select=self.text_cross_attention_maps.index(editing_prompt[c]), ) attn_map = out[:, :, :, 1 : 1 + num_edit_tokens[c]] # 0 -> startoftext # average over all tokens if attn_map.shape[3] != num_edit_tokens[c]: raise ValueError( f"Incorrect shape of attention_map. Expected size {num_edit_tokens[c]}, but found {attn_map.shape[3]}!" 
) attn_map = torch.sum(attn_map, dim=3) # gaussian_smoothing attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect") attn_map = self.smoothing(attn_map).squeeze(1) # torch.quantile function expects float32 if attn_map.dtype == torch.float32: tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1) else: tmp = torch.quantile( attn_map.flatten(start_dim=1).to(torch.float32), edit_threshold_c, dim=1 ).to(attn_map.dtype) attn_mask = torch.where( attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1, *att_res), 1.0, 0.0 ) # resolution must match latent space dimension attn_mask = F.interpolate( attn_mask.unsqueeze(1), noise_guidance_edit_tmp.shape[-2:], # 64,64 ).repeat(1, 4, 1, 1) self.activation_mask[i, c] = attn_mask.detach().cpu() if not use_intersect_mask: noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask if use_intersect_mask: noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) noise_guidance_edit_tmp_quantile = torch.sum( noise_guidance_edit_tmp_quantile, dim=1, keepdim=True ) noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat( 1, self.unet.config.in_channels, 1, 1 ) # torch.quantile function expects float32 if noise_guidance_edit_tmp_quantile.dtype == torch.float32: tmp = torch.quantile( noise_guidance_edit_tmp_quantile.flatten(start_dim=2), edit_threshold_c, dim=2, keepdim=False, ) else: tmp = torch.quantile( noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), edit_threshold_c, dim=2, keepdim=False, ).to(noise_guidance_edit_tmp_quantile.dtype) intersect_mask = ( torch.where( noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], torch.ones_like(noise_guidance_edit_tmp), torch.zeros_like(noise_guidance_edit_tmp), ) * attn_mask ) self.activation_mask[i, c] = intersect_mask.detach().cpu() noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask elif not use_cross_attn_mask: # calculate quantile noise_guidance_edit_tmp_quantile = 
torch.abs(noise_guidance_edit_tmp) noise_guidance_edit_tmp_quantile = torch.sum( noise_guidance_edit_tmp_quantile, dim=1, keepdim=True ) noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1) # torch.quantile function expects float32 if noise_guidance_edit_tmp_quantile.dtype == torch.float32: tmp = torch.quantile( noise_guidance_edit_tmp_quantile.flatten(start_dim=2), edit_threshold_c, dim=2, keepdim=False, ) else: tmp = torch.quantile( noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), edit_threshold_c, dim=2, keepdim=False, ).to(noise_guidance_edit_tmp_quantile.dtype) self.activation_mask[i, c] = ( torch.where( noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], torch.ones_like(noise_guidance_edit_tmp), torch.zeros_like(noise_guidance_edit_tmp), ) .detach() .cpu() ) noise_guidance_edit_tmp = torch.where( noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], noise_guidance_edit_tmp, torch.zeros_like(noise_guidance_edit_tmp), ) noise_guidance_edit += noise_guidance_edit_tmp self.sem_guidance[i] = noise_guidance_edit.detach().cpu() noise_pred = noise_pred_uncond + noise_guidance_edit # compute the previous noisy sample x_t -> x_t-1 if enable_edit_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg( noise_pred, noise_pred_edit_concepts.mean(dim=0, keepdim=False), guidance_rescale=self.guidance_rescale, ) idx = t_to_idx[int(t)] latents = self.scheduler.step( noise_pred, t, latents, variance_noise=zs[idx], **extra_step_kwargs, return_dict=False )[0] # step callback if use_cross_attn_mask: store_step = i in attn_store_steps self.attention_store.between_steps(store_step) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) # negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > 0 and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: xm.mark_step() if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents if not output_type == "latent": # apply watermark if available if self.watermark is not None: image = 
self.watermark.apply_watermark(image) image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return LEditsPPDiffusionPipelineOutput(images=image, nsfw_content_detected=None) @torch.no_grad() # Modified from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.encode_image def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="default", crops_coords=None): image = self.image_processor.preprocess( image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords ) resized = self.image_processor.postprocess(image=image, output_type="pil") if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5: logger.warning( "Your input images far exceed the default resolution of the underlying diffusion model. " "The output images may contain severe artifacts! " "Consider down-sampling the input using the `height` and `width` parameters" ) image = image.to(self.device, dtype=dtype) needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: image = image.float() self.upcast_vae() x0 = self.vae.encode(image).latent_dist.mode() x0 = x0.to(dtype) # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) x0 = self.vae.config.scaling_factor * x0 return x0, resized @torch.no_grad() def invert( self, image: PipelineImageInput, source_prompt: str = "", source_guidance_scale=3.5, negative_prompt: str = None, negative_prompt_2: str = None, num_inversion_steps: int = 50, skip: float = 0.15, generator: Optional[torch.Generator] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), num_zero_noise_steps: int = 3, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" The function to the pipeline for image inversion as described by the [LEDITS++ Paper](https://arxiv.org/abs/2301.12247). 
If the scheduler is set to [`~schedulers.DDIMScheduler`], the inversion proposed by [edit-friendly DDPM](https://arxiv.org/abs/2304.06140) will be performed instead. Args: image (`PipelineImageInput`): Input for the image(s) that are to be edited. Multiple input images have to default to the same aspect ratio. source_prompt (`str`, defaults to `""`): Prompt describing the input image that will be used for guidance during inversion. Guidance is disabled if the `source_prompt` is `""`. source_guidance_scale (`float`, defaults to `3.5`): Strength of guidance during inversion. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. num_inversion_steps (`int`, defaults to `50`): Number of total performed inversion steps after discarding the initial `skip` steps. skip (`float`, defaults to `0.15`): Portion of initial steps that will be ignored for inversion and subsequent generation. Lower values will lead to stronger changes to the input image. `skip` has to be between `0` and `1`. generator (`torch.Generator`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make inversion deterministic. crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). 
Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). num_zero_noise_steps (`int`, defaults to `3`): Number of final diffusion steps that will not renoise the current image. If no steps are set to zero SD-XL in combination with [`DPMSolverMultistepScheduler`] will produce noise artifacts. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). Returns: [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s) and respective VAE reconstruction(s). """ # Reset attn processor, we do not want to store attn maps during inversion self.unet.set_attn_processor(AttnProcessor()) self.eta = 1.0 self.scheduler.config.timestep_spacing = "leading" self.scheduler.set_timesteps(int(num_inversion_steps * (1 + skip))) self.inversion_steps = self.scheduler.timesteps[-num_inversion_steps:] timesteps = self.inversion_steps num_images_per_prompt = 1 device = self._execution_device # 0. Ensure that only uncond embedding is used if prompt = "" if source_prompt == "": # noise pred should only be noise_pred_uncond source_guidance_scale = 0.0 do_classifier_free_guidance = False else: do_classifier_free_guidance = source_guidance_scale > 1.0 # 1. prepare image x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype) width = x0.shape[2] * self.vae_scale_factor height = x0.shape[3] * self.vae_scale_factor self.size = (height, width) self.batch_size = x0.shape[0] # 2. 
get embeddings text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) if isinstance(source_prompt, str): source_prompt = [source_prompt] * self.batch_size ( negative_prompt_embeds, prompt_embeds, negative_pooled_prompt_embeds, edit_pooled_prompt_embeds, _, ) = self.encode_prompt( device=device, num_images_per_prompt=num_images_per_prompt, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, editing_prompt=source_prompt, lora_scale=text_encoder_lora_scale, enable_edit_guidance=do_classifier_free_guidance, ) if self.text_encoder_2 is None: text_encoder_projection_dim = int(negative_pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim # 3. Prepare added time ids & embeddings add_text_embeds = negative_pooled_prompt_embeds add_time_ids = self._get_add_time_ids( self.size, crops_coords_top_left, self.size, dtype=negative_prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if do_classifier_free_guidance: negative_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([add_text_embeds, edit_pooled_prompt_embeds], dim=0) add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) negative_prompt_embeds = negative_prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(self.batch_size * num_images_per_prompt, 1) # autoencoder reconstruction if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() x0_tmp = x0.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image_rec = self.vae.decode( x0_tmp / self.vae.config.scaling_factor, return_dict=False, generator=generator )[0] elif self.vae.config.force_upcast: x0_tmp = x0.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image_rec = self.vae.decode( x0_tmp / self.vae.config.scaling_factor, return_dict=False, 
generator=generator )[0] else: image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0] image_rec = self.image_processor.postprocess(image_rec, output_type="pil") # 5. find zs and xts variance_noise_shape = (num_inversion_steps, *x0.shape) # intermediate latents t_to_idx = {int(v): k for k, v in enumerate(timesteps)} xts = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype) for t in reversed(timesteps): idx = num_inversion_steps - t_to_idx[int(t)] - 1 noise = randn_tensor(shape=x0.shape, generator=generator, device=self.device, dtype=x0.dtype) xts[idx] = self.scheduler.add_noise(x0, noise, t.unsqueeze(0)) xts = torch.cat([x0.unsqueeze(0), xts], dim=0) # noise maps zs = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype) self.scheduler.set_timesteps(len(self.scheduler.timesteps)) for t in self.progress_bar(timesteps): idx = num_inversion_steps - t_to_idx[int(t)] - 1 # 1. predict noise residual xt = xts[idx + 1] latent_model_input = torch.cat([xt] * 2) if do_classifier_free_guidance else xt latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=negative_prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # 2. 
perform guidance if do_classifier_free_guidance: noise_pred_out = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] noise_pred = noise_pred_uncond + source_guidance_scale * (noise_pred_text - noise_pred_uncond) xtm1 = xts[idx] z, xtm1_corrected = compute_noise(self.scheduler, xtm1, xt, t, noise_pred, self.eta) zs[idx] = z # correction to avoid error accumulation xts[idx] = xtm1_corrected self.init_latents = xts[-1] zs = zs.flip(0) if num_zero_noise_steps > 0: zs[-num_zero_noise_steps:] = torch.zeros_like(zs[-num_zero_noise_steps:]) self.zs = zs return LEditsPPInversionPipelineOutput(images=resized, vae_reconstruction_images=image_rec) # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise_ddim def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, eta): # 1. get previous step value (=t-1) prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps # 2. 
compute alphas, betas alpha_prod_t = scheduler.alphas_cumprod[timestep] alpha_prod_t_prev = ( scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod ) beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) # 4. Clip "predicted x_0" if scheduler.config.clip_sample: pred_original_sample = torch.clamp(pred_original_sample, -1, 1) # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = scheduler._get_variance(timestep, prev_timestep) std_dev_t = eta * variance ** (0.5) # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred # modifed so that updated xtm1 is returned as well (to avoid error accumulation) mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction if variance > 0.0: noise = (prev_latents - mu_xt) / (variance ** (0.5) * eta) else: noise = torch.tensor([0.0]).to(latents.device) return noise, mu_xt + (eta * variance**0.5) * noise # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise_sde_dpm_pp_2nd def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noise_pred, eta): def first_order_update(model_output, sample): # timestep, prev_timestep, sample): sigma_t, sigma_s = scheduler.sigmas[scheduler.step_index + 1], scheduler.sigmas[scheduler.step_index] alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = scheduler._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s = torch.log(alpha_s) - torch.log(sigma_s) h = lambda_t - lambda_s mu_xt = (sigma_t / sigma_s * 
torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output mu_xt = scheduler.dpm_solver_first_order_update( model_output=model_output, sample=sample, noise=torch.zeros_like(sample) ) sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) if sigma > 0.0: noise = (prev_latents - mu_xt) / sigma else: noise = torch.tensor([0.0]).to(sample.device) prev_sample = mu_xt + sigma * noise return noise, prev_sample def second_order_update(model_output_list, sample): # timestep_list, prev_timestep, sample): sigma_t, sigma_s0, sigma_s1 = ( scheduler.sigmas[scheduler.step_index + 1], scheduler.sigmas[scheduler.step_index], scheduler.sigmas[scheduler.step_index - 1], ) alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = scheduler._sigma_to_alpha_sigma_t(sigma_s0) alpha_s1, sigma_s1 = scheduler._sigma_to_alpha_sigma_t(sigma_s1) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) m0, m1 = model_output_list[-1], model_output_list[-2] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) mu_xt = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 ) sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) if sigma > 0.0: noise = (prev_latents - mu_xt) / sigma else: noise = torch.tensor([0.0]).to(sample.device) prev_sample = mu_xt + sigma * noise return noise, prev_sample if scheduler.step_index is None: scheduler._init_step_index(timestep) model_output = scheduler.convert_model_output(model_output=noise_pred, sample=latents) for i in range(scheduler.config.solver_order - 1): scheduler.model_outputs[i] = scheduler.model_outputs[i + 1] scheduler.model_outputs[-1] = model_output if scheduler.lower_order_nums < 1: noise, prev_sample = first_order_update(model_output, latents) else: noise, prev_sample 
@dataclass
class LEditsPPInversionPipelineOutput(BaseOutput):
    """
    Output class for LEdits++ inversion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of the cropped and resized input images as PIL images of length `batch_size` or NumPy array of shape
            `(batch_size, height, width, num_channels)`.
        vae_reconstruction_images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of VAE reconstruction of all input images as PIL images of length `batch_size` or NumPy array of
            shape `(batch_size, height, width, num_channels)`.
    """

    # NOTE: the field is named `images` (not `input_images`); the docstring above follows the field name.
    images: Union[List[PIL.Image.Image], np.ndarray]
    vae_reconstruction_images: Union[List[PIL.Image.Image], np.ndarray]
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import html import inspect import math import re import urllib.parse as ul from typing import List, Optional, Tuple, Union import torch from transformers import AutoModel, AutoTokenizer from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL from ...models.embeddings import get_2d_rotary_pos_embed_lumina from ...models.transformers.lumina_nextdit2d import LuminaNextDiT2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( BACKENDS_MAPPING, is_bs4_available, is_ftfy_available, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_bs4_available(): from bs4 import BeautifulSoup if is_ftfy_available(): import ftfy EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import LuminaText2ImgPipeline >>> pipe = LuminaText2ImgPipeline.from_pretrained( ... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16 ... ).cuda() >>> # Enable memory optimizations. >>> pipe.enable_model_cpu_offload() >>> prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure the scheduler via `set_timesteps` and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or the scheduler's own
    spacing for `num_inference_steps`. Extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read the timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps; only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Passed through to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the scheduler's timestep schedule and the number of inference steps (the length
        of that schedule when a custom override was supplied, otherwise `num_inference_steps` as passed in).

    Raises:
        ValueError: if both overrides are passed, or the scheduler's `set_timesteps` does not accept the override.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler derive its own schedule.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: both overrides follow the same path, differing only in keyword name and error wording.
    if timesteps is not None:
        override_name, override_label, override_value = "timesteps", "timestep", timesteps
    else:
        override_name, override_label, override_value = "sigmas", "sigmas", sigmas

    if override_name not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {override_label} schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(**{override_name: override_value}, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
transformer ([`Transformer2DModel`]): A text conditioned `Transformer2DModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. """ bad_punct_regex = re.compile( r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}" ) # noqa _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" def __init__( self, transformer: LuminaNextDiT2DModel, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, text_encoder: AutoModel, tokenizer: AutoTokenizer, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler, ) self.vae_scale_factor = 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.max_sequence_length = 256 self.default_sample_size = ( self.transformer.config.sample_size if hasattr(self, "transformer") and self.transformer is not None else 128 ) self.default_image_size = self.default_sample_size * self.vae_scale_factor def _get_gemma_prompt_embeds( self, prompt: Union[str, List[str]], num_images_per_prompt: int = 1, device: Optional[torch.device] = None, clean_caption: Optional[bool] = False, max_length: Optional[int] = None, ): device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, pad_to_multiple_of=8, max_length=self.max_sequence_length, truncation=True, padding=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = 
self.tokenizer.batch_decode(untruncated_ids[:, self.max_sequence_length - 1 : -1]) logger.warning( "The following part of your input was truncated because Gemma can only handle sequences up to" f" {self.max_sequence_length} tokens: {removed_text}" ) prompt_attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True ) prompt_embeds = prompt_embeds.hidden_states[-2] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.transformer is not None: dtype = self.transformer.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, prompt_attention_mask # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], do_classifier_free_guidance: bool = True, negative_prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, device: Optional[torch.device] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, clean_caption: bool = False, **kwargs, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded negative_prompt (`str` or `List[str]`, *optional*): The prompt not to guide the image generation. 
If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For Lumina-T2I, this should be "". do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt device: (`torch.device`, *optional*): torch device to place the resulting embeddings on prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. clean_caption (`bool`, defaults to `False`): If `True`, the function will preprocess and clean the provided caption before encoding. max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. 
""" if device is None: device = self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( prompt=prompt, num_images_per_prompt=num_images_per_prompt, device=device, clean_caption=clean_caption, ) # Get negative embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt if negative_prompt is not None else "" # Normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): negative_prompt = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments accepted by `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only included when the scheduler's
    `step` declares a parameter of that name. `eta` (η) corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        `dict`: a subset of `{"eta": eta, "generator": generator}`.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
) if prompt_embeds is not None and prompt_attention_mask is None: raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: raise ValueError( "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" f" {negative_prompt_attention_mask.shape}." 
def _clean_caption(self, caption):
    """
    Normalize a raw caption before tokenization.

    The caption is URL-unquoted, lower-cased and stripped, then passed through a fixed sequence of regular
    expression substitutions that remove URLs, HTML, @-mentions, CJK ranges, ids, file names and marketing
    boilerplate, and that normalize dashes, quotes and whitespace. Relies on the module-level `BeautifulSoup` and
    `ftfy` imports being available.

    Args:
        caption: Any object; it is converted with `str()` first.

    Returns:
        `str`: the cleaned caption.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # The literal tag must be matched here: an empty pattern would insert "person" between every character.
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @nickname mentions
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # HTML entity for a double quote
    caption = re.sub(r"&quot;?", "", caption)
    # HTML entity for an ampersand
    caption = re.sub(r"&amp", "", caption)

    # ip adresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # \n
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # repeated quotes / dots and unwanted punctuation
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE: a bare `caption.strip()` whose result was discarded used to sit here. It had no effect, so it is
    # dropped rather than assigned, which keeps the output unchanged.

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
    num_inference_steps: int = 30,
    timesteps: List[int] = None,
    guidance_scale: float = 4.0,
    negative_prompt: Union[str, List[str]] = None,
    sigmas: List[float] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    clean_caption: bool = True,
    max_sequence_length: int = 256,
    scaling_watershed: Optional[float] = 1.0,
    proportional_attn: Optional[bool] = True,
) -> Union[ImagePipelineOutput, Tuple]:
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image. Must be divisible by 8.
        height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image. Must be divisible by 8.
        num_inference_steps (`int`, *optional*, defaults to 30):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, an empty string is used unless
            `negative_prompt_embeds` is passed. Ignored when guidance is disabled (`guidance_scale` not greater
            than `1`).
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. Mutually exclusive with `timesteps`.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents to be used as inputs for image generation. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`.
            Requires `prompt_attention_mask`.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. For Lumina-T2I this negative prompt should be "". If not
            provided, they are generated from `negative_prompt`. Requires `negative_prompt_attention_mask`.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Pre-generated attention mask for text embeddings.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Pre-generated attention mask for negative text embeddings.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"latent"` returns the final latents without VAE decoding;
            any other value is passed to the image processor's `postprocess`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy`
            to be installed. If the dependencies are not installed, the embeddings will be created from the raw
            prompt.
        max_sequence_length (`int`, defaults to 256):
            Forwarded to `encode_prompt` as a keyword argument.
        scaling_watershed (`float`, *optional*, defaults to 1.0):
            Threshold on the (reversed, normalized) timestep. Below it the resolution scaling factor is applied as
            `linear_factor` of the rotary embedding; at or above it the factor is applied as `ntk_factor`.
        proportional_attn (`bool`, *optional*, defaults to `True`):
            When `True`, `base_sequence_length = (self.default_image_size // 16) ** 2` is passed to the transformer
            through `cross_attention_kwargs`.

    Examples:

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
            returned where the first element is a list with the generated images
    """
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
    )

    # Back the `guidance_scale` / `do_classifier_free_guidance` properties; without this assignment they raise
    # AttributeError because `_guidance_scale` is never set anywhere else.
    self._guidance_scale = guidance_scale

    cross_attention_kwargs = {}

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if proportional_attn:
        cross_attention_kwargs["base_sequence_length"] = (self.default_image_size // 16) ** 2

    # Ratio between the requested resolution and the default one, used to rescale the rotary embedding.
    scaling_factor = math.sqrt(width * height / self.default_image_size**2)

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        clean_caption=clean_caption,
        max_sequence_length=max_sequence_length,
    )
    if do_classifier_free_guidance:
        # Conditional half first, unconditional half second; the guidance split below relies on this order.
        prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
        prompt_attention_mask = torch.cat([prompt_attention_mask, negative_prompt_attention_mask], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )
    # Back the `num_timesteps` property, which otherwise raises AttributeError.
    self._num_timesteps = len(timesteps)

    # 5. Prepare latents.
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        latent_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for t in timesteps:
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            current_timestep = t
            if not torch.is_tensor(current_timestep):
                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                # This would be a good case for the `match` statement (Python 3.10+)
                is_mps = latent_model_input.device.type == "mps"
                if isinstance(current_timestep, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                current_timestep = torch.tensor(
                    [current_timestep],
                    dtype=dtype,
                    device=latent_model_input.device,
                )
            elif len(current_timestep.shape) == 0:
                current_timestep = current_timestep[None].to(latent_model_input.device)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            current_timestep = current_timestep.expand(latent_model_input.shape[0])

            # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
            current_timestep = 1 - current_timestep / self.scheduler.config.num_train_timesteps

            # prepare image_rotary_emb for positional encoding
            # dynamic scaling_factor for different resolution.
            # NOTE: For `Time-aware` denosing mechanism from Lumina-Next
            # https://arxiv.org/abs/2406.18583, Sec 2.3
            # NOTE: We should compute different image_rotary_emb with different timestep.
            if current_timestep[0] < scaling_watershed:
                linear_factor = scaling_factor
                ntk_factor = 1.0
            else:
                linear_factor = 1.0
                ntk_factor = scaling_factor
            image_rotary_emb = get_2d_rotary_pos_embed_lumina(
                self.transformer.head_dim,
                384,
                384,
                linear_factor=linear_factor,
                ntk_factor=ntk_factor,
            )

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=current_timestep,
                encoder_hidden_states=prompt_embeds,
                encoder_mask=prompt_attention_mask,
                image_rotary_emb=image_rotary_emb,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            # Keep only the first half of the channel dimension of the transformer output.
            noise_pred = noise_pred.chunk(2, dim=1)[0]

            # perform guidance scale
            # NOTE: For exact reproducibility reasons, we apply classifier-free guidance on only
            # three channels by default. The standard approach to cfg applies it to all channels.
            if do_classifier_free_guidance:
                noise_pred_eps, noise_pred_rest = noise_pred[:, :3], noise_pred[:, 3:]
                noise_pred_cond_eps, noise_pred_uncond_eps = torch.split(
                    noise_pred_eps, len(noise_pred_eps) // 2, dim=0
                )
                noise_pred_half = noise_pred_uncond_eps + guidance_scale * (
                    noise_pred_cond_eps - noise_pred_uncond_eps
                )
                noise_pred_eps = torch.cat([noise_pred_half, noise_pred_half], dim=0)

                noise_pred = torch.cat([noise_pred_eps, noise_pred_rest], dim=1)
                noise_pred, _ = noise_pred.chunk(2, dim=0)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            noise_pred = -noise_pred
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            progress_bar.update()

    if output_type != "latent":
        latents = latents / self.vae.config.scaling_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)
    else:
        image = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["marigold_image_processing"] = ["MarigoldImageProcessor"] _import_structure["pipeline_marigold_depth"] = ["MarigoldDepthOutput", "MarigoldDepthPipeline"] _import_structure["pipeline_marigold_normals"] = ["MarigoldNormalsOutput", "MarigoldNormalsPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .marigold_image_processing import MarigoldImageProcessor from .pipeline_marigold_depth import MarigoldDepthOutput, MarigoldDepthPipeline from .pipeline_marigold_normals import MarigoldNormalsOutput, MarigoldNormalsPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/marigold/marigold_image_processing.py ================================================ from typing import List, Optional, Tuple, Union import numpy as np import PIL import torch import torch.nn.functional as F from PIL import Image from ... 
class MarigoldImageProcessor(ConfigMixin):
    """
    Image pre-processing and visualization helpers shared by the Marigold pipelines.

    Inputs (PIL images, NumPy arrays, or torch tensors) are converted into a canonical floating point
    `[N,3,H,W]` tensor, optionally normalized, resized, and padded. The remaining static methods turn
    predictions (depth, normals, uncertainty) into `PIL.Image.Image` objects.

    Args:
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            Alignment (in pixels) to which `preprocess` pads the image height and width.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether `preprocess` maps the image values from `[0,1]` to `[-1,1]`.
        do_range_check (`bool`, *optional*, defaults to `True`):
            Whether `preprocess` verifies that the canonical image values lie within `[0,1]`.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        vae_scale_factor: int = 8,
        do_normalize: bool = True,
        do_range_check: bool = True,
    ):
        super().__init__()

    @staticmethod
    def expand_tensor_or_array(images: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
        """
        Add the missing leading (and, for 2D arrays, trailing) singleton dimensions so that a 2D or 3D input
        becomes 4D. Inputs of any other rank are returned unchanged.

        NumPy arrays are treated as channels-last (`[H,W]` or `[H,W,C]`), tensors as channels-first
        (`[H,W]` or `[1,H,W]`).

        Raises:
            ValueError: If `images` is neither a `np.ndarray` nor a `torch.Tensor`.
        """
        if isinstance(images, np.ndarray):
            if images.ndim == 2:  # [H,W] -> [1,H,W,1]
                images = images[None, ..., None]
            elif images.ndim == 3:  # [H,W,C] -> [1,H,W,C]
                images = images[None]
        elif isinstance(images, torch.Tensor):
            if images.ndim == 2:  # [H,W] -> [1,1,H,W]
                images = images[None, None]
            elif images.ndim == 3:  # [1,H,W] -> [1,1,H,W]
                images = images[None]
        else:
            raise ValueError(f"Unexpected input type: {type(images)}")
        return images

    @staticmethod
    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
        """
        Convert a `[N,C,H,W]` PyTorch tensor to a `[N,H,W,C]` float32 NumPy array on the CPU.
        """
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        return images

    @staticmethod
    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
        """
        Convert a `[N,H,W,C]` NumPy array to a `[N,C,H,W]` PyTorch tensor.

        The result is created with `torch.from_numpy` on a transposed view, so it shares memory with `images`.

        Raises:
            ValueError: If the array dtype is a signed integer, complex, or boolean.
        """
        if np.issubdtype(images.dtype, np.integer) and not np.issubdtype(images.dtype, np.unsignedinteger):
            raise ValueError(f"Input image dtype={images.dtype} cannot be a signed integer.")
        if np.issubdtype(images.dtype, np.complexfloating):
            raise ValueError(f"Input image dtype={images.dtype} cannot be complex.")
        if np.issubdtype(images.dtype, bool):
            raise ValueError(f"Input image dtype={images.dtype} cannot be boolean.")

        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
        return images

    @staticmethod
    def resize_antialias(
        image: torch.Tensor, size: Tuple[int, int], mode: str, is_aa: Optional[bool] = None
    ) -> torch.Tensor:
        """
        Resize a 4D floating point tensor to `size` with `torch.nn.functional.interpolate`.

        Anti-aliasing is enabled only when `is_aa` is truthy and `mode` is `"bilinear"` or `"bicubic"`.

        Raises:
            ValueError: If `image` is not a 4D floating point tensor.
        """
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        # `is_aa` may be `None`; always hand a real boolean to `F.interpolate`.
        antialias = bool(is_aa) and mode in ("bilinear", "bicubic")
        image = F.interpolate(image, size, mode=mode, antialias=antialias)

        return image

    @staticmethod
    def resize_to_max_edge(image: torch.Tensor, max_edge_sz: int, mode: str) -> torch.Tensor:
        """
        Resize a 4D floating point tensor so that its longer spatial edge equals `max_edge_sz`, scaling the
        shorter edge proportionally (rounded down).

        Raises:
            ValueError: If `image` is not a 4D floating point tensor, or if either resized edge rounds to zero.
        """
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        h, w = image.shape[-2:]
        max_orig = max(h, w)
        new_h = h * max_edge_sz // max_orig
        new_w = w * max_edge_sz // max_orig

        if new_h == 0 or new_w == 0:
            raise ValueError(f"Extreme aspect ratio of the input image: [{w} x {h}]")

        image = MarigoldImageProcessor.resize_antialias(image, (new_h, new_w), mode, is_aa=True)

        return image

    @staticmethod
    def pad_image(image: torch.Tensor, align: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """
        Pad the bottom and right edges of a 4D floating point tensor (replicating border values) so that its
        height and width become multiples of `align`.

        Returns:
            The padded tensor and the `(ph, pw)` amounts added to the height and width.

        Raises:
            ValueError: If `image` is not a 4D floating point tensor.
        """
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        h, w = image.shape[-2:]
        ph, pw = -h % align, -w % align

        image = F.pad(image, (0, pw, 0, ph), mode="replicate")

        return image, (ph, pw)

    @staticmethod
    def unpad_image(image: torch.Tensor, padding: Tuple[int, int]) -> torch.Tensor:
        """
        Remove `(ph, pw)` rows and columns from the bottom and right edges of a 4D floating point tensor,
        i.e. undo `pad_image`.

        Raises:
            ValueError: If `image` is not a 4D floating point tensor.
        """
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        ph, pw = padding
        # A slice end of `-0` would produce an empty tensor, hence `None` when nothing was padded.
        uh = None if ph == 0 else -ph
        uw = None if pw == 0 else -pw

        image = image[:, :, :uh, :uw]

        return image

    @staticmethod
    def load_image_canonical(
        image: Union[torch.Tensor, np.ndarray, Image.Image],
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """
        Convert a single input into a `[N,3,H,W]` tensor of the requested `dtype` on `device`.

        Unsigned integer inputs are divided by the maximum value of their dtype; single-channel inputs are
        repeated to three channels. Floating point inputs are passed through without rescaling.

        Raises:
            ValueError: If the input type, dtype, rank, or channel count is unsupported.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)

        image_dtype_max = None
        if isinstance(image, (np.ndarray, torch.Tensor)):
            image = MarigoldImageProcessor.expand_tensor_or_array(image)
            if image.ndim != 4:
                raise ValueError("Input image is not 2-, 3-, or 4-dimensional.")
        if isinstance(image, np.ndarray):
            if np.issubdtype(image.dtype, np.integer) and not np.issubdtype(image.dtype, np.unsignedinteger):
                raise ValueError(f"Input image dtype={image.dtype} cannot be a signed integer.")
            if np.issubdtype(image.dtype, np.complexfloating):
                raise ValueError(f"Input image dtype={image.dtype} cannot be complex.")
            if np.issubdtype(image.dtype, bool):
                raise ValueError(f"Input image dtype={image.dtype} cannot be boolean.")
            if np.issubdtype(image.dtype, np.unsignedinteger):
                image_dtype_max = np.iinfo(image.dtype).max
                image = image.astype(np.float32)  # because torch does not have unsigned dtypes beyond torch.uint8
            image = MarigoldImageProcessor.numpy_to_pt(image)

        if torch.is_tensor(image) and not torch.is_floating_point(image) and image_dtype_max is None:
            if image.dtype != torch.uint8:
                raise ValueError(f"Image dtype={image.dtype} is not supported.")
            image_dtype_max = 255

        if not torch.is_tensor(image):
            raise ValueError(f"Input type unsupported: {type(image)}.")

        if image.shape[1] == 1:
            image = image.repeat(1, 3, 1, 1)  # [N,1,H,W] -> [N,3,H,W]
        if image.shape[1] != 3:
            raise ValueError(f"Input image is not 1- or 3-channel: {image.shape}.")

        image = image.to(device=device, dtype=dtype)

        if image_dtype_max is not None:
            image = image / image_dtype_max

        return image

    @staticmethod
    def check_image_values_range(image: torch.Tensor) -> None:
        """
        Verify that all values of a floating point tensor lie within `[0,1]`.

        Raises:
            ValueError: If `image` is not a floating point tensor or has values outside of `[0,1]`.
        """
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.min().item() < 0.0 or image.max().item() > 1.0:
            raise ValueError("Input image data is partially outside of the [0,1] range.")

    def preprocess(
        self,
        image: PipelineImageInput,
        processing_resolution: Optional[int] = None,
        resample_method_input: str = "bilinear",
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        """
        Load one image or a list of images into a single batch and prepare it for the model.

        The steps are: canonical loading, optional `[0,1]` range check, optional normalization to `[-1,1]`,
        optional resizing of the longer edge to `processing_resolution` (skipped when it is `None` or not
        positive), and padding to a multiple of `vae_scale_factor`.

        Returns:
            A tuple `(image, padding, original_resolution)`: the padded batch, the `(ph, pw)` padding that was
            added, and the `(H, W)` of the input before resizing.

        Raises:
            ValueError: If the list is empty, the list items differ in spatial size, or an input is invalid.
        """
        if isinstance(image, list):
            loaded = []
            for i, img in enumerate(image):
                img = self.load_image_canonical(img, device, dtype)  # [N,3,H,W]
                if loaded and loaded[0].shape[2:] != img.shape[2:]:
                    raise ValueError(
                        f"Input image[{i}] has incompatible dimensions {img.shape[2:]} with the previous images "
                        f"{loaded[0].shape[2:]}"
                    )
                loaded.append(img)
            if not loaded:
                raise ValueError("Input image list is empty.")
            # Concatenate once instead of growing the batch inside the loop.
            image = loaded[0] if len(loaded) == 1 else torch.cat(loaded, dim=0)
            del loaded
        else:
            image = self.load_image_canonical(image, device, dtype)  # [N,3,H,W]

        original_resolution = image.shape[2:]

        if self.config.do_range_check:
            self.check_image_values_range(image)

        if self.config.do_normalize:
            image = image * 2.0 - 1.0

        if processing_resolution is not None and processing_resolution > 0:
            image = self.resize_to_max_edge(image, processing_resolution, resample_method_input)  # [N,3,PH,PW]

        image, padding = self.pad_image(image, self.config.vae_scale_factor)  # [N,3,PPH,PPW]

        return image, padding, original_resolution

    @staticmethod
    def colormap(
        image: Union[np.ndarray, torch.Tensor],
        cmap: str = "Spectral",
        bytes: bool = False,
        _force_method: Optional[str] = None,
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Converts a monochrome image into an RGB image by applying the specified colormap. This function mimics the
        behavior of matplotlib.colormaps, but allows the user to use the most discriminative color maps ("Spectral",
        "binary") without having to install or import matplotlib. For all other cases, the function will attempt to use
        the native implementation.

        Args:
            image: 2D tensor of values between 0 and 1, either as np.ndarray or torch.Tensor.
            cmap: Colormap name.
            bytes: Whether to return the output as uint8 or floating point image.
            _force_method:
                Can be used to specify whether to use the native implementation (`"matplotlib"`), the efficient custom
                implementation of the select color maps (`"custom"`), or rely on autodetection (`None`, default).

        Returns:
            An RGB-colorized tensor corresponding to the input image.

        Raises:
            ValueError: If the input type, `_force_method`, or the color map name is not supported.
            ImportError: If `_force_method="matplotlib"` is requested but matplotlib is not installed.
        """
        if not (torch.is_tensor(image) or isinstance(image, np.ndarray)):
            raise ValueError("Argument must be a numpy array or torch tensor.")
        if _force_method not in (None, "matplotlib", "custom"):
            raise ValueError("_force_method must be either `None`, `'matplotlib'` or `'custom'`.")

        supported_cmaps = {
            "binary": [
                (1.0, 1.0, 1.0),
                (0.0, 0.0, 0.0),
            ],
            "Spectral": [  # Taken from matplotlib/_cm.py
                (0.61960784313725492, 0.003921568627450980, 0.25882352941176473),  # 0.0 -> [0]
                (0.83529411764705885, 0.24313725490196078, 0.30980392156862746),
                (0.95686274509803926, 0.42745098039215684, 0.2627450980392157),
                (0.99215686274509807, 0.68235294117647061, 0.38039215686274508),
                (0.99607843137254903, 0.8784313725490196, 0.54509803921568623),
                (1.0, 1.0, 0.74901960784313726),
                (0.90196078431372551, 0.96078431372549022, 0.59607843137254901),
                (0.6705882352941176, 0.8666666666666667, 0.64313725490196083),
                (0.4, 0.76078431372549016, 0.6470588235294118),
                (0.19607843137254902, 0.53333333333333333, 0.74117647058823533),
                (0.36862745098039218, 0.30980392156862746, 0.63529411764705879),  # 1.0 -> [K-1]
            ],
        }

        def method_matplotlib(image, cmap, bytes=False):
            # Returns `None` (instead of raising) when matplotlib is missing so the caller can fall back.
            if is_matplotlib_available():
                import matplotlib
            else:
                return None

            arg_is_pt, device = torch.is_tensor(image), None
            if arg_is_pt:
                image, device = image.cpu().numpy(), image.device

            if cmap not in matplotlib.colormaps:
                raise ValueError(
                    f"Unexpected color map {cmap}; available options are: {', '.join(list(matplotlib.colormaps.keys()))}"
                )

            cmap = matplotlib.colormaps[cmap]
            out = cmap(image, bytes=bytes)  # [?,4]
            out = out[..., :3]  # [?,3]

            if arg_is_pt:
                out = torch.tensor(out, device=device)

            return out

        def method_custom(image, cmap, bytes=False):
            arg_is_np = isinstance(image, np.ndarray)
            if arg_is_np:
                image = torch.tensor(image)
            if image.dtype == torch.uint8:
                image = image.float() / 255
            else:
                image = image.float()

            # A trailing "_r" selects the reversed variant of a color map, as in matplotlib.
            is_cmap_reversed = cmap.endswith("_r")
            if is_cmap_reversed:
                cmap = cmap[:-2]

            if cmap not in supported_cmaps:
                raise ValueError(
                    f"Only {list(supported_cmaps.keys())} color maps are available without installing matplotlib."
                )

            cmap = supported_cmaps[cmap]
            if is_cmap_reversed:
                cmap = cmap[::-1]
            cmap = torch.tensor(cmap, dtype=torch.float, device=image.device)  # [K,3]
            K = cmap.shape[0]

            # Linear interpolation between the two nearest color map entries.
            pos = image.clamp(min=0, max=1) * (K - 1)
            left = pos.long()
            right = (left + 1).clamp(max=K - 1)

            d = (pos - left.float()).unsqueeze(-1)
            left_colors = cmap[left]
            right_colors = cmap[right]

            out = (1 - d) * left_colors + d * right_colors

            if bytes:
                out = (out * 255).to(torch.uint8)

            if arg_is_np:
                out = out.numpy()

            return out

        if _force_method is None and torch.is_tensor(image) and cmap == "Spectral":
            return method_custom(image, cmap, bytes)

        out = None
        if _force_method != "custom":
            out = method_matplotlib(image, cmap, bytes)

        if _force_method == "matplotlib" and out is None:
            raise ImportError("Make sure to install matplotlib if you want to use a color map other than 'Spectral'.")

        if out is None:
            out = method_custom(image, cmap, bytes)

        return out

    @staticmethod
    def visualize_depth(
        depth: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.Tensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.Tensor],
        ],
        val_min: float = 0.0,
        val_max: float = 1.0,
        color_map: str = "Spectral",
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
        """
        Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`.

        Args:
            depth (`Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray],
                List[torch.Tensor]]`): Depth maps.
            val_min (`float`, *optional*, defaults to `0.0`): Minimum value of the visualized depth range.
            val_max (`float`, *optional*, defaults to `1.0`): Maximum value of the visualized depth range.
            color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel
                      depth prediction into colored representation.

        Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with depth maps visualization.

        Raises:
            ValueError: If the value range is empty or the input has an unexpected type, shape, or dtype.
        """
        if val_max <= val_min:
            raise ValueError(f"Invalid values range: [{val_min}, {val_max}].")

        def visualize_depth_one(img, idx=None):
            # `idx` may legitimately be 0, so test against `None` rather than truthiness.
            prefix = "Depth" + (f"[{idx}]" if idx is not None else "")
            if isinstance(img, PIL.Image.Image):
                if img.mode != "I;16":
                    raise ValueError(f"{prefix}: invalid PIL mode={img.mode}.")
                img = np.array(img).astype(np.float32) / (2**16 - 1)
            if isinstance(img, np.ndarray) or torch.is_tensor(img):
                if img.ndim != 2:
                    raise ValueError(f"{prefix}: unexpected shape={img.shape}.")
                if isinstance(img, np.ndarray):
                    img = torch.from_numpy(img)
                if not torch.is_floating_point(img):
                    raise ValueError(f"{prefix}: unexected dtype={img.dtype}.")
            else:
                raise ValueError(f"{prefix}: unexpected type={type(img)}.")
            if val_min != 0.0 or val_max != 1.0:
                img = (img - val_min) / (val_max - val_min)
            img = MarigoldImageProcessor.colormap(img, cmap=color_map, bytes=True)  # [H,W,3]
            img = PIL.Image.fromarray(img.cpu().numpy())
            return img

        if depth is None or isinstance(depth, list) and any(o is None for o in depth):
            raise ValueError("Input depth is `None`")
        if isinstance(depth, (np.ndarray, torch.Tensor)):
            depth = MarigoldImageProcessor.expand_tensor_or_array(depth)
            if isinstance(depth, np.ndarray):
                depth = MarigoldImageProcessor.numpy_to_pt(depth)  # [N,H,W,1] -> [N,1,H,W]
            if not (depth.ndim == 4 and depth.shape[1] == 1):  # [N,1,H,W]
                raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].")
            return [visualize_depth_one(img[0], idx) for idx, img in enumerate(depth)]
        elif isinstance(depth, list):
            return [visualize_depth_one(img, idx) for idx, img in enumerate(depth)]
        else:
            raise ValueError(f"Unexpected input type: {type(depth)}")

    @staticmethod
    def export_depth_to_16bit_png(
        depth: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
        val_min: float = 0.0,
        val_max: float = 1.0,
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
        """
        Convert floating point depth maps into 16-bit (`I;16`) PIL images.

        Values are first mapped from `[val_min, val_max]` to `[0,1]` and then scaled to the `uint16` range.

        Returns:
            A list with one `PIL.Image.Image` per depth map.

        Raises:
            ValueError: If the input has an unexpected type, shape, or dtype.
        """

        def export_depth_to_16bit_png_one(img, idx=None):
            # `idx` may legitimately be 0, so test against `None` rather than truthiness.
            prefix = "Depth" + (f"[{idx}]" if idx is not None else "")
            if not isinstance(img, np.ndarray) and not torch.is_tensor(img):
                raise ValueError(f"{prefix}: unexpected type={type(img)}.")
            if img.ndim != 2:
                raise ValueError(f"{prefix}: unexpected shape={img.shape}.")
            if torch.is_tensor(img):
                img = img.cpu().numpy()
            if not np.issubdtype(img.dtype, np.floating):
                raise ValueError(f"{prefix}: unexected dtype={img.dtype}.")
            if val_min != 0.0 or val_max != 1.0:
                img = (img - val_min) / (val_max - val_min)
            img = (img * (2**16 - 1)).astype(np.uint16)
            img = PIL.Image.fromarray(img, mode="I;16")
            return img

        if depth is None or isinstance(depth, list) and any(o is None for o in depth):
            raise ValueError("Input depth is `None`")
        if isinstance(depth, (np.ndarray, torch.Tensor)):
            depth = MarigoldImageProcessor.expand_tensor_or_array(depth)
            if isinstance(depth, np.ndarray):
                depth = MarigoldImageProcessor.numpy_to_pt(depth)  # [N,H,W,1] -> [N,1,H,W]
            if not (depth.ndim == 4 and depth.shape[1] == 1):
                raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].")
            return [export_depth_to_16bit_png_one(img[0], idx) for idx, img in enumerate(depth)]
        elif isinstance(depth, list):
            return [export_depth_to_16bit_png_one(img, idx) for idx, img in enumerate(depth)]
        else:
            raise ValueError(f"Unexpected input type: {type(depth)}")

    @staticmethod
    def visualize_normals(
        normals: Union[
            np.ndarray,
            torch.Tensor,
            List[np.ndarray],
            List[torch.Tensor],
        ],
        flip_x: bool = False,
        flip_y: bool = False,
        flip_z: bool = False,
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
        """
        Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`.

        The input is never modified; axis flips are applied to a copy.

        Args:
            normals (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
                Surface normals.
            flip_x (`bool`, *optional*, defaults to `False`): Flips the X axis of the normals frame of reference.
                      Default direction is right.
            flip_y (`bool`, *optional*, defaults to `False`):  Flips the Y axis of the normals frame of reference.
                      Default direction is top.
            flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference.
                      Default direction is facing the observer.

        Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with surface normals visualization.

        Raises:
            ValueError: If the input is `None` or has an unexpected type or shape.
        """
        flip_vec = None
        if any((flip_x, flip_y, flip_z)):
            flip_vec = torch.tensor(
                [
                    (-1) ** flip_x,
                    (-1) ** flip_y,
                    (-1) ** flip_z,
                ],
                dtype=torch.float32,
            )

        def visualize_normals_one(img, idx=None):
            img = img.permute(1, 2, 0)
            if flip_vec is not None:
                # Out-of-place on purpose: `img` is a view of the caller's data (and of the caller's NumPy
                # array when the input went through `torch.from_numpy`), so `*=` would modify the input.
                img = img * flip_vec.to(device=img.device, dtype=img.dtype)
            img = (img + 1.0) * 0.5
            img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy()
            img = PIL.Image.fromarray(img)
            return img

        if normals is None or isinstance(normals, list) and any(o is None for o in normals):
            raise ValueError("Input normals is `None`")
        if isinstance(normals, (np.ndarray, torch.Tensor)):
            normals = MarigoldImageProcessor.expand_tensor_or_array(normals)
            if isinstance(normals, np.ndarray):
                normals = MarigoldImageProcessor.numpy_to_pt(normals)  # [N,3,H,W]
            if not (normals.ndim == 4 and normals.shape[1] == 3):
                raise ValueError(f"Unexpected input shape={normals.shape}, expecting [N,3,H,W].")
            return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)]
        elif isinstance(normals, list):
            return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)]
        else:
            raise ValueError(f"Unexpected input type: {type(normals)}")

    @staticmethod
    def visualize_uncertainty(
        uncertainty: Union[
            np.ndarray,
            torch.Tensor,
            List[np.ndarray],
            List[torch.Tensor],
        ],
        saturation_percentile=95,
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
        """
        Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline` or `MarigoldNormalsPipeline`.

        Args:
            uncertainty (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
                Uncertainty maps.
            saturation_percentile (`int`, *optional*, defaults to `95`):
                Specifies the percentile uncertainty value visualized with maximum intensity.

        Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with uncertainty visualization.

        Raises:
            ValueError: If the input is `None`, has negative values, or has an unexpected type or shape.
        """

        def visualize_uncertainty_one(img, idx=None):
            # `idx` may legitimately be 0, so test against `None` rather than truthiness.
            prefix = "Uncertainty" + (f"[{idx}]" if idx is not None else "")
            if img.min() < 0:
                raise ValueError(f"{prefix}: unexected data range, min={img.min()}.")
            img = img.squeeze(0).cpu().numpy()
            saturation_value = np.percentile(img, saturation_percentile)
            if saturation_value > 0:
                img = np.clip(img * 255 / saturation_value, 0, 255)
            else:
                # Avoid dividing by zero: anything above a zero saturation level is fully saturated.
                img = np.where(img > 0, 255, 0)
            img = img.astype(np.uint8)
            img = PIL.Image.fromarray(img)
            return img

        if uncertainty is None or isinstance(uncertainty, list) and any(o is None for o in uncertainty):
            raise ValueError("Input uncertainty is `None`")
        if isinstance(uncertainty, (np.ndarray, torch.Tensor)):
            uncertainty = MarigoldImageProcessor.expand_tensor_or_array(uncertainty)
            if isinstance(uncertainty, np.ndarray):
                uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty)  # [N,1,H,W]
            if not (uncertainty.ndim == 4 and uncertainty.shape[1] == 1):
                raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,1,H,W].")
            return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
        elif isinstance(uncertainty, list):
            return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
        else:
            raise ValueError(f"Unexpected input type: {type(uncertainty)}")
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # -------------------------------------------------------------------------- # More information and citation instructions are available on the # Marigold project website: https://marigoldmonodepth.github.io # -------------------------------------------------------------------------- from dataclasses import dataclass from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from ...image_processor import PipelineImageInput from ...models import ( AutoencoderKL, UNet2DConditionModel, ) from ...schedulers import ( DDIMScheduler, LCMScheduler, ) from ...utils import ( BaseOutput, logging, replace_example_docstring, ) from ...utils.import_utils import is_scipy_available from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .marigold_image_processing import MarigoldImageProcessor logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import diffusers >>> import torch >>> pipe = diffusers.MarigoldDepthPipeline.from_pretrained( ... "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16 ... 
@dataclass
class MarigoldDepthOutput(BaseOutput):
    r"""
    Output class for Marigold monocular depth prediction pipeline.

    The docstring is a raw string so that the `\times` in the shape formulas below is kept verbatim instead
    of being parsed as a TAB escape followed by "imes".

    Args:
        prediction (`np.ndarray`, `torch.Tensor`):
            Predicted depth maps with values in the range [0, 1]. The shape is always $numimages \times 1 \times height
            \times width$, regardless of whether the images were passed as a 4D array or a list.
        uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
            Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
            \times 1 \times height \times width$.
        latent (`None`, `torch.Tensor`):
            Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
            The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
    """

    prediction: Union[np.ndarray, torch.Tensor]
    uncertainty: Union[None, np.ndarray, torch.Tensor]
    latent: Union[None, torch.Tensor]
scheduler (`DDIMScheduler` or `LCMScheduler`): A scheduler to be used in combination with `unet` to denoise the encoded image latents. text_encoder (`CLIPTextModel`): Text-encoder, for empty text embedding. tokenizer (`CLIPTokenizer`): CLIP tokenizer. prediction_type (`str`, *optional*): Type of predictions made by the model. scale_invariant (`bool`, *optional*): A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in the model config. When used together with the `shift_invariant=True` flag, the model is also called "affine-invariant". NB: overriding this value is not supported. shift_invariant (`bool`, *optional*): A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in the model config. When used together with the `scale_invariant=True` flag, the model is also called "affine-invariant". NB: overriding this value is not supported. default_denoising_steps (`int`, *optional*): The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable quality with the given model. This value must be set in the model config. When the pipeline is called without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure reasonable results with various model flavors compatible with the pipeline, such as those relying on very short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). default_processing_resolution (`int`, *optional*): The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in the model config. When the pipeline is called without explicitly setting `processing_resolution`, the default value is used. This is required to ensure reasonable results with various model flavors trained with varying optimal processing resolution values. 
def __init__(
    self,
    unet: UNet2DConditionModel,
    vae: AutoencoderKL,
    scheduler: Union[DDIMScheduler, LCMScheduler],
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    prediction_type: Optional[str] = None,
    scale_invariant: Optional[bool] = True,
    shift_invariant: Optional[bool] = True,
    default_denoising_steps: Optional[int] = None,
    default_processing_resolution: Optional[int] = None,
):
    """
    Register the model components and configuration values, and create the image processor.

    A warning (not an error) is logged when `prediction_type` is not one of `supported_prediction_types`.
    """
    super().__init__()

    prediction_type_is_known = prediction_type in self.supported_prediction_types
    if not prediction_type_is_known:
        logger.warning(
            f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: "
            f"{self.supported_prediction_types}."
        )

    modules = {
        "unet": unet,
        "vae": vae,
        "scheduler": scheduler,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
    }
    self.register_modules(**modules)

    config_values = {
        "prediction_type": prediction_type,
        "scale_invariant": scale_invariant,
        "shift_invariant": shift_invariant,
        "default_denoising_steps": default_denoising_steps,
        "default_processing_resolution": default_processing_resolution,
    }
    self.register_to_config(**config_values)

    # One factor of two per VAE block after the first; read from the registered VAE's config.
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)

    self.scale_invariant = scale_invariant
    self.shift_invariant = shift_invariant
    self.default_denoising_steps = default_denoising_steps
    self.default_processing_resolution = default_processing_resolution

    # Starts unset; nothing in this constructor fills it in.
    self.empty_text_embedding = None

    self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
def check_inputs(
    self,
    image: PipelineImageInput,
    num_inference_steps: int,
    ensemble_size: int,
    processing_resolution: int,
    resample_method_input: str,
    resample_method_output: str,
    batch_size: int,
    ensembling_kwargs: Optional[Dict[str, Any]],
    latents: Optional[torch.Tensor],
    generator: Optional[Union[torch.Generator, List[torch.Generator]]],
    output_type: str,
    output_uncertainty: bool,
) -> int:
    """
    Validate the arguments of a pipeline call and count the input images.

    The checks run in a fixed order (scalar arguments, then images, then `latents`, then `generator`), so the
    first violated constraint determines the raised error.

    Returns:
        `int`: Total number of input images; a 4D array or tensor contributes its leading dimension.

    Raises:
        ValueError: If any argument is missing, out of range, of an unsupported type, or inconsistent with the
            others, including when `image` contains no images at all.
        ImportError: If ensembling with alignment is requested but scipy is not available.
    """
    supported_resample_methods = ("nearest", "nearest-exact", "bilinear", "bicubic", "area")

    if num_inference_steps is None:
        raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
    if num_inference_steps < 1:
        raise ValueError("`num_inference_steps` must be positive.")
    if ensemble_size < 1:
        raise ValueError("`ensemble_size` must be positive.")
    if ensemble_size == 2:
        logger.warning(
            "`ensemble_size` == 2 results are similar to no ensembling (1); "
            "consider increasing the value to at least 3."
        )
    # scipy is only needed when the ensemble members have to be aligned.
    if ensemble_size > 1 and (self.scale_invariant or self.shift_invariant) and not is_scipy_available():
        raise ImportError("Make sure to install scipy if you want to use ensembling.")
    if ensemble_size == 1 and output_uncertainty:
        raise ValueError(
            "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` "
            "greater than 1."
        )
    if processing_resolution is None:
        raise ValueError(
            "`processing_resolution` is not specified and could not be resolved from the model config."
        )
    if processing_resolution < 0:
        raise ValueError(
            "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
            "downsampled processing."
        )
    if processing_resolution % self.vae_scale_factor != 0:
        raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
    if resample_method_input not in supported_resample_methods:
        raise ValueError(
            "`resample_method_input` takes string values compatible with PIL library: "
            "nearest, nearest-exact, bilinear, bicubic, area."
        )
    if resample_method_output not in supported_resample_methods:
        raise ValueError(
            "`resample_method_output` takes string values compatible with PIL library: "
            "nearest, nearest-exact, bilinear, bicubic, area."
        )
    if batch_size < 1:
        raise ValueError("`batch_size` must be positive.")
    if output_type not in ["pt", "np"]:
        raise ValueError("`output_type` must be one of `pt` or `np`.")
    if latents is not None and generator is not None:
        raise ValueError("`latents` and `generator` cannot be used together.")
    if ensembling_kwargs is not None:
        if not isinstance(ensembling_kwargs, dict):
            raise ValueError("`ensembling_kwargs` must be a dictionary.")
        if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("mean", "median"):
            raise ValueError("`ensembling_kwargs['reduction']` can be either `'mean'` or `'median'`.")

    # image checks
    num_images = 0
    W, H = None, None
    if not isinstance(image, list):
        image = [image]
    for i, img in enumerate(image):
        if isinstance(img, np.ndarray) or torch.is_tensor(img):
            if img.ndim not in (2, 3, 4):
                raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
            # NOTE(review): the last two axes are taken as (H, W) for arrays and tensors alike — confirm this
            # matches the memory layout the image processor expects for `np.ndarray` inputs.
            H_i, W_i = img.shape[-2:]
            N_i = 1
            if img.ndim == 4:
                N_i = img.shape[0]
        elif isinstance(img, Image.Image):
            W_i, H_i = img.size
            N_i = 1
        else:
            raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
        if W is None:
            W, H = W_i, H_i
        elif (W, H) != (W_i, H_i):
            raise ValueError(
                f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
            )
        num_images += N_i

    # An empty list (or a zero-length batch) would otherwise slip through and fail later with an unrelated
    # error, e.g. `max(None, None)` in the latents check below.
    if num_images == 0:
        raise ValueError("`image` must contain at least one input image.")

    # latents checks
    if latents is not None:
        if not torch.is_tensor(latents):
            raise ValueError("`latents` must be a torch.Tensor.")
        if latents.dim() != 4:
            raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.")

        if processing_resolution > 0:
            # Reproduce the aspect-preserving resize: the longer side becomes `processing_resolution`.
            max_orig = max(H, W)
            new_H = H * processing_resolution // max_orig
            new_W = W * processing_resolution // max_orig
            if new_H == 0 or new_W == 0:
                raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]")
            W, H = new_W, new_H
        # Ceil-division accounts for padding up to a multiple of the VAE scale factor.
        w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor
        h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor
        shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w)

        if latents.shape != shape_expected:
            raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.")

    # generator checks
    if generator is not None:
        if isinstance(generator, list):
            if len(generator) != num_images * ensemble_size:
                raise ValueError(
                    "The number of generators must match the total number of ensemble members for all input images."
                )
            if not all(g.device.type == generator[0].device.type for g in generator):
                raise ValueError("`generator` device placement is not consistent in the list.")
        elif not isinstance(generator, torch.Generator):
            raise ValueError(f"Unsupported generator type: {type(generator)}.")

    return num_images
def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
    """
    Build a `tqdm` progress bar, honouring any settings stored in `self._progress_bar_config`.

    `desc` and `leave` act as fallbacks only: keys already present in the stored config win. Either `iterable`
    or `total` must be provided; `iterable` takes precedence when both are given.

    Raises:
        ValueError: If the stored config is not a `dict`, or if neither `iterable` nor `total` is given.
    """
    if hasattr(self, "_progress_bar_config"):
        if not isinstance(self._progress_bar_config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )
    else:
        # First use: start from an empty configuration.
        self._progress_bar_config = {}

    # Work on a copy so that per-call fallbacks never leak into the stored config.
    tqdm_kwargs = dict(**self._progress_bar_config)
    tqdm_kwargs.setdefault("desc", desc)
    tqdm_kwargs.setdefault("leave", leave)

    if iterable is not None:
        return tqdm(iterable, **tqdm_kwargs)
    if total is None:
        raise ValueError("Either `total` or `iterable` has to be defined.")
    return tqdm(total=total, **tqdm_kwargs)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    image: PipelineImageInput,
    num_inference_steps: Optional[int] = None,
    ensemble_size: int = 1,
    processing_resolution: Optional[int] = None,
    match_input_resolution: bool = True,
    resample_method_input: str = "bilinear",
    resample_method_output: str = "bilinear",
    batch_size: int = 1,
    ensembling_kwargs: Optional[Dict[str, Any]] = None,
    latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    output_type: str = "np",
    output_uncertainty: bool = False,
    output_latent: bool = False,
    return_dict: bool = True,
):
    """
    Function invoked when calling the pipeline.

    Args:
        image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
            `List[torch.Tensor]`): An input image or images used as an input for the depth estimation task. For
            arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is
            possible by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
            three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
            same width and height.
        num_inference_steps (`int`, *optional*, defaults to `None`):
            Number of denoising diffusion steps during inference. The default value `None` results in automatic
            selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
            for Marigold-LCM models.
        ensemble_size (`int`, defaults to `1`):
            Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
            faster inference.
        processing_resolution (`int`, *optional*, defaults to `None`):
            Effective processing resolution. When set to `0`, matches the larger input image dimension. This
            produces crisper predictions, but may also lead to the overall loss of global context. The default
            value `None` resolves to the optimal value from the model config.
        match_input_resolution (`bool`, *optional*, defaults to `True`):
            When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
            side of the output will equal to `processing_resolution`.
        resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
            Resampling method used to resize input images to `processing_resolution`. The accepted values are:
            `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
        resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
            Resampling method used to resize output predictions to match the input resolution. The accepted values
            are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
        batch_size (`int`, *optional*, defaults to `1`):
            Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
        ensembling_kwargs (`dict`, *optional*, defaults to `None`)
            Extra dictionary with arguments for precise ensembling control. The dictionary is forwarded as keyword
            arguments to `ensemble_depth`. The following options are available:
            - reduction (`str`, *optional*, defaults to `"median"`): Defines the ensembling function applied in
              every pixel location, can be either `"median"` or `"mean"`.
            - regularizer_strength (`float`, *optional*, defaults to `0.02`): Strength of the regularizer that
              pulls the aligned predictions to the unit range from 0 to 1.
            - max_iter (`int`, *optional*, defaults to `2`): Maximum number of the alignment solver steps. Refer to
              `scipy.optimize.minimize` function, `options` argument.
            - tol (`float`, *optional*, defaults to `1e-3`): Alignment solver tolerance. The solver stops when the
              tolerance is reached.
            - max_res (`int`, *optional*, defaults to `1024`): Resolution at which the alignment is performed;
              `None` disables the downsampling, so the alignment runs at the `processing_resolution`.
        latents (`torch.Tensor`, or `List[torch.Tensor]`, *optional*, defaults to `None`):
            Latent noise tensors to replace the random initialization. These can be taken from the previous
            function call's output.
        generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`):
            Random number generator object to ensure reproducibility.
        output_type (`str`, *optional*, defaults to `"np"`):
            Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted
            values are: `"np"` (numpy array) or `"pt"` (torch tensor).
        output_uncertainty (`bool`, *optional*, defaults to `False`):
            When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that
            the `ensemble_size` argument is set to a value greater than 1.
        output_latent (`bool`, *optional*, defaults to `False`):
            When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
            within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
            `latents` argument.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple.

    Examples:

    Returns:
        [`~pipelines.marigold.MarigoldDepthOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.marigold.MarigoldDepthOutput`] is returned, otherwise a
            `tuple` is returned where the first element is the prediction, the second element is the uncertainty
            (or `None`), and the third is the latent (or `None`).
    """

    # 0. Resolving variables.
    device = self._execution_device
    dtype = self.dtype

    # Model-specific optimal default values leading to fast and reasonable results.
    if num_inference_steps is None:
        num_inference_steps = self.default_denoising_steps
    if processing_resolution is None:
        processing_resolution = self.default_processing_resolution

    # 1. Check inputs.
    num_images = self.check_inputs(
        image,
        num_inference_steps,
        ensemble_size,
        processing_resolution,
        resample_method_input,
        resample_method_output,
        batch_size,
        ensembling_kwargs,
        latents,
        generator,
        output_type,
        output_uncertainty,
    )

    # 2. Prepare empty text conditioning. The embedding is computed once and cached on the pipeline instance.
    # Model invocation: self.tokenizer, self.text_encoder.
    if self.empty_text_embedding is None:
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)
        self.empty_text_embedding = self.text_encoder(text_input_ids)[0]  # [1,2,1024]

    # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`,
    # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where
    # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are
    # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None`
    # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of
    # operation and leads to the most reasonable results. Using the native image resolution or any other processing
    # resolution can lead to loss of either fine details or global context in the output predictions.
    image, padding, original_resolution = self.image_processor.preprocess(
        image, processing_resolution, resample_method_input, device, dtype
    )  # [N,3,PPH,PPW]

    # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
    # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
    # Latents of each such predictions across all input images and all ensemble members are represented in the
    # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
    # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure
    # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline
    # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken
    # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled
    # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
    # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
    # Model invocation: self.vae.encoder.
    image_latent, pred_latent = self.prepare_latents(
        image, latents, generator, ensemble_size, batch_size
    )  # [N*E,4,h,w], [N*E,4,h,w]

    del image

    batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat(
        batch_size, 1, 1
    )  # [B,2,1024]

    # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`.
    # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and
    # outputs noise for the predicted modality's latent space. The number of denoising diffusion steps is defined by
    # `num_inference_steps`. It is either set directly, or resolves to the optimal value specific to the loaded
    # model.
    # Model invocation: self.unet.
    pred_latents = []

    for i in self.progress_bar(
        range(0, num_images * ensemble_size, batch_size), leave=True, desc="Marigold predictions..."
    ):
        batch_image_latent = image_latent[i : i + batch_size]  # [B,4,h,w]
        batch_pred_latent = pred_latent[i : i + batch_size]  # [B,4,h,w]
        # The last batch may be smaller than `batch_size`.
        effective_batch_size = batch_image_latent.shape[0]
        text = batch_empty_text_embedding[:effective_batch_size]  # [B,2,1024]

        # The scheduler is reset for every batch, so each batch runs the full schedule.
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."):
            batch_latent = torch.cat([batch_image_latent, batch_pred_latent], dim=1)  # [B,8,h,w]
            noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0]  # [B,4,h,w]
            batch_pred_latent = self.scheduler.step(
                noise, t, batch_pred_latent, generator=generator
            ).prev_sample  # [B,4,h,w]

        pred_latents.append(batch_pred_latent)

    pred_latent = torch.cat(pred_latents, dim=0)  # [N*E,4,h,w]

    # Drop references to the intermediate tensors before decoding.
    del (
        pred_latents,
        image_latent,
        batch_empty_text_embedding,
        batch_image_latent,
        batch_pred_latent,
        text,
        batch_latent,
        noise,
    )

    # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`,
    # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`.
    # Model invocation: self.vae.decoder.
    prediction = torch.cat(
        [
            self.decode_prediction(pred_latent[i : i + batch_size])
            for i in range(0, pred_latent.shape[0], batch_size)
        ],
        dim=0,
    )  # [N*E,1,PPH,PPW]

    if not output_latent:
        pred_latent = None

    # 7. Remove padding. The output shape is (PH, PW).
    prediction = self.image_processor.unpad_image(prediction, padding)  # [N*E,1,PH,PW]

    # 8. Ensemble and compute uncertainty (when `output_uncertainty` is set). This code treats each of the `N`
    # groups of `E` ensemble predictions independently. For each group it computes an ensembled prediction of shape
    # `(PH, PW)` and an optional uncertainty map of the same dimensions. After computing this pair of outputs for
    # each group independently, it stacks them respectively into batches of `N` almost final predictions and
    # uncertainty maps.
    uncertainty = None
    if ensemble_size > 1:
        prediction = prediction.reshape(num_images, ensemble_size, *prediction.shape[1:])  # [N,E,1,PH,PW]
        prediction = [
            self.ensemble_depth(
                prediction[i],
                self.scale_invariant,
                self.shift_invariant,
                output_uncertainty,
                **(ensembling_kwargs or {}),
            )
            for i in range(num_images)
        ]  # [ [[1,1,PH,PW], [1,1,PH,PW]], ... ]
        prediction, uncertainty = zip(*prediction)  # [[1,1,PH,PW], ... ], [[1,1,PH,PW], ... ]
        prediction = torch.cat(prediction, dim=0)  # [N,1,PH,PW]
        if output_uncertainty:
            uncertainty = torch.cat(uncertainty, dim=0)  # [N,1,PH,PW]
        else:
            uncertainty = None

    # 9. If `match_input_resolution` is set, the output prediction and the uncertainty are upsampled to match the
    # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled.
    # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by
    # setting the `resample_method_output` parameter (e.g., to `"nearest"`).
    if match_input_resolution:
        prediction = self.image_processor.resize_antialias(
            prediction, original_resolution, resample_method_output, is_aa=False
        )  # [N,1,H,W]
        if uncertainty is not None and output_uncertainty:
            uncertainty = self.image_processor.resize_antialias(
                uncertainty, original_resolution, resample_method_output, is_aa=False
            )  # [N,1,H,W]

    # 10. Prepare the final outputs.
    if output_type == "np":
        prediction = self.image_processor.pt_to_numpy(prediction)  # [N,H,W,1]
        if uncertainty is not None and output_uncertainty:
            uncertainty = self.image_processor.pt_to_numpy(uncertainty)  # [N,H,W,1]

    # 11. Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (prediction, uncertainty, pred_latent)

    return MarigoldDepthOutput(
        prediction=prediction,
        uncertainty=uncertainty,
        latent=pred_latent,
    )
def prepare_latents(
    self,
    image: torch.Tensor,
    latents: Optional[torch.Tensor],
    generator: Optional[torch.Generator],
    ensemble_size: int,
    batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode the input images with the VAE and provide the initial prediction latents.

    Images are encoded in chunks of `batch_size`, multiplied by the VAE scaling factor, and each image latent is
    repeated `ensemble_size` times along the batch dimension. When `latents` is `None`, the prediction latents are
    sampled with `randn_tensor` to match the image latents; otherwise `latents` is returned unchanged.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: The image latents and the prediction latents.
    """

    def retrieve_latents(encoder_output):
        # Prefer the mode of the posterior; fall back to plain latents if that is all the encoder exposes.
        if hasattr(encoder_output, "latent_dist"):
            return encoder_output.latent_dist.mode()
        if hasattr(encoder_output, "latents"):
            return encoder_output.latents
        raise AttributeError("Could not access latents of provided encoder_output")

    encoded_chunks = []
    for start in range(0, image.shape[0], batch_size):
        chunk = image[start : start + batch_size]
        encoded_chunks.append(retrieve_latents(self.vae.encode(chunk)))

    image_latent = torch.cat(encoded_chunks, dim=0)  # [N,4,h,w]
    image_latent = image_latent * self.vae.config.scaling_factor
    image_latent = image_latent.repeat_interleave(ensemble_size, dim=0)  # [N*E,4,h,w]

    if latents is not None:
        return image_latent, latents

    pred_latent = randn_tensor(
        image_latent.shape,
        generator=generator,
        device=image_latent.device,
        dtype=image_latent.dtype,
    )  # [N*E,4,h,w]
    return image_latent, pred_latent
def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
    """
    Decode prediction latents into single-channel maps with values in `[0, 1]`.

    The latents are divided by the VAE scaling factor and decoded; the decoder output is averaged over the
    channel dimension, clipped to `[-1, 1]`, and mapped linearly onto `[0, 1]`.

    Raises:
        ValueError: If `pred_latent` is not a 4D tensor with `self.vae.config.latent_channels` channels.
    """
    is_4d = pred_latent.dim() == 4
    if not is_4d or pred_latent.shape[1] != self.vae.config.latent_channels:
        raise ValueError(
            f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
        )

    unscaled_latent = pred_latent / self.vae.config.scaling_factor
    decoded = self.vae.decode(unscaled_latent, return_dict=False)[0]  # [B,3,H,W]

    single_channel = decoded.mean(dim=1, keepdim=True)  # [B,1,H,W]
    clipped = torch.clip(single_channel, -1.0, 1.0)  # [B,1,H,W]

    return (clipped + 1.0) / 2.0  # [B,1,H,W]
@staticmethod
def ensemble_depth(
    depth: torch.Tensor,
    scale_invariant: bool = True,
    shift_invariant: bool = True,
    output_uncertainty: bool = False,
    reduction: str = "median",
    regularizer_strength: float = 0.02,
    max_iter: int = 2,
    tol: float = 1e-3,
    max_res: int = 1024,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Ensembles the depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the
    number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for
    depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The
    alignment happens when the predictions have one or more degrees of freedom, that is when they are either
    affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only
    `scale_invariant=True`). Aligned results are additionally normalized so that the maximum is 1 (and, in the
    affine-invariant case, the minimum is 0). For absolute predictions (`scale_invariant=False` and
    `shift_invariant=False`) alignment and normalization are skipped and only ensembling is performed.

    Args:
        depth (`torch.Tensor`):
            Input ensemble depth maps.
        scale_invariant (`bool`, *optional*, defaults to `True`):
            Whether to treat predictions as scale-invariant.
        shift_invariant (`bool`, *optional*, defaults to `True`):
            Whether to treat predictions as shift-invariant.
        output_uncertainty (`bool`, *optional*, defaults to `False`):
            Whether to output uncertainty map.
        reduction (`str`, *optional*, defaults to `"median"`):
            Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and
            `"median"`.
        regularizer_strength (`float`, *optional*, defaults to `0.02`):
            Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1.
        max_iter (`int`, *optional*, defaults to `2`):
            Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options`
            argument.
        tol (`float`, *optional*, defaults to `1e-3`):
            Alignment solver tolerance. The solver stops when the tolerance is reached.
        max_res (`int`, *optional*, defaults to `1024`):
            Resolution at which the alignment is performed; `None` disables the downsampling, so the alignment
            runs at the resolution of `depth`.

    Returns:
        A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape:
        `(1, 1, H, W)`.

    Raises:
        ValueError: If `depth` is not of shape `(B, 1, H, W)`, if `reduction` is not recognized, or if pure
            shift-invariant ensembling is requested.
    """
    if depth.dim() != 4 or depth.shape[1] != 1:
        raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.")
    if reduction not in ("mean", "median"):
        raise ValueError(f"Unrecognized reduction method: {reduction}.")
    if not scale_invariant and shift_invariant:
        raise ValueError("Pure shift-invariant ensembling is not supported.")

    def init_param(depth: torch.Tensor):
        # Initial guess: map every member's own value range onto [0, 1] (or its maximum onto 1).
        init_min = depth.reshape(ensemble_size, -1).min(dim=1).values
        init_max = depth.reshape(ensemble_size, -1).max(dim=1).values

        if scale_invariant and shift_invariant:
            init_s = 1.0 / (init_max - init_min).clamp(min=1e-6)
            init_t = -init_s * init_min
            param = torch.cat((init_s, init_t)).cpu().numpy()
        elif scale_invariant:
            init_s = 1.0 / init_max.clamp(min=1e-6)
            param = init_s.cpu().numpy()
        else:
            raise ValueError("Unrecognized alignment.")

        return param

    def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor:
        # `param` holds the per-member scales, followed by the per-member shifts in the affine case.
        if scale_invariant and shift_invariant:
            s, t = np.split(param, 2)
            s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1)
            t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1)
            out = depth * s + t
        elif scale_invariant:
            s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1)
            out = depth * s
        else:
            raise ValueError("Unrecognized alignment.")
        return out

    def ensemble(
        depth_aligned: torch.Tensor, return_uncertainty: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Per-pixel reduction over the ensemble dimension; the uncertainty is the matching spread measure
        # (standard deviation for the mean, median absolute deviation for the median).
        uncertainty = None
        if reduction == "mean":
            prediction = torch.mean(depth_aligned, dim=0, keepdim=True)
            if return_uncertainty:
                uncertainty = torch.std(depth_aligned, dim=0, keepdim=True)
        elif reduction == "median":
            prediction = torch.median(depth_aligned, dim=0, keepdim=True).values
            if return_uncertainty:
                uncertainty = torch.median(torch.abs(depth_aligned - prediction), dim=0, keepdim=True).values
        else:
            raise ValueError(f"Unrecognized reduction method: {reduction}.")
        return prediction, uncertainty

    def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:
        # Sum of pairwise RMS differences between aligned members, plus an optional unit-range regularizer.
        cost = 0.0
        depth_aligned = align(depth, param)

        for i, j in torch.combinations(torch.arange(ensemble_size)):
            diff = depth_aligned[i] - depth_aligned[j]
            cost += (diff**2).mean().sqrt().item()

        if regularizer_strength > 0:
            prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
            err_near = (0.0 - prediction.min()).abs().item()
            err_far = (1.0 - prediction.max()).abs().item()
            cost += (err_near + err_far) * regularizer_strength

        return cost

    def compute_param(depth: torch.Tensor):
        # Import the submodule explicitly: a bare `import scipy` does not guarantee `scipy.optimize` is loaded.
        import scipy.optimize

        depth_to_align = depth.to(torch.float32)
        if max_res is not None and max(depth_to_align.shape[2:]) > max_res:
            depth_to_align = MarigoldImageProcessor.resize_to_max_edge(depth_to_align, max_res, "nearest-exact")

        param = init_param(depth_to_align)

        res = scipy.optimize.minimize(
            partial(cost_fn, depth=depth_to_align),
            param,
            method="BFGS",
            tol=tol,
            options={"maxiter": max_iter, "disp": False},
        )

        return res.x

    requires_aligning = scale_invariant or shift_invariant
    ensemble_size = depth.shape[0]

    if requires_aligning:
        param = compute_param(depth)
        depth = align(depth, param)

    depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty)

    if not requires_aligning:
        # Absolute predictions carry meaningful units: return the plain ensemble without renormalizing.
        return depth, uncertainty  # [1,1,H,W], [1,1,H,W]

    depth_max = depth.max()
    # Pure shift-invariance was rejected above, so aligned predictions are always scale-invariant here.
    depth_min = depth.min() if shift_invariant else 0
    depth_range = (depth_max - depth_min).clamp(min=1e-6)
    depth = (depth - depth_min) / depth_range
    if output_uncertainty:
        uncertainty /= depth_range

    return depth, uncertainty  # [1,1,H,W], [1,1,H,W]
alignment.") depth_range = (depth_max - depth_min).clamp(min=1e-6) depth = (depth - depth_min) / depth_range if output_uncertainty: uncertainty /= depth_range return depth, uncertainty # [1,1,H,W], [1,1,H,W] ================================================ FILE: src/diffusers/pipelines/marigold/pipeline_marigold_normals.py ================================================ # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved. # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# -------------------------------------------------------------------------- # More information and citation instructions are available on the # Marigold project website: https://marigoldmonodepth.github.io # -------------------------------------------------------------------------- from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from ...image_processor import PipelineImageInput from ...models import ( AutoencoderKL, UNet2DConditionModel, ) from ...schedulers import ( DDIMScheduler, LCMScheduler, ) from ...utils import ( BaseOutput, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .marigold_image_processing import MarigoldImageProcessor logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import diffusers >>> import torch >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained( ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16 ... ).to("cuda") >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") >>> normals = pipe(image) >>> vis = pipe.image_processor.visualize_normals(normals.prediction) >>> vis[0].save("einstein_normals.png") ``` """ @dataclass class MarigoldNormalsOutput(BaseOutput): """ Output class for Marigold monocular normals prediction pipeline. Args: prediction (`np.ndarray`, `torch.Tensor`): Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height \times width$, regardless of whether the images were passed as a 4D array or a list. uncertainty (`None`, `np.ndarray`, `torch.Tensor`): Uncertainty maps computed from the ensemble, with values in the range [0, 1]. 
@dataclass
class MarigoldNormalsOutput(BaseOutput):
    # NOTE: the docstring is a raw string so that the LaTeX `\times` is not parsed as a tab escape (`\t`).
    r"""
    Output class for Marigold monocular normals prediction pipeline.

    Args:
        prediction (`np.ndarray`, `torch.Tensor`):
            Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height
            \times width$, regardless of whether the images were passed as a 4D array or a list.
        uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
            Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
            \times 1 \times height \times width$.
        latent (`None`, `torch.Tensor`):
            Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
            The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
    """

    prediction: Union[np.ndarray, torch.Tensor]
    uncertainty: Union[None, np.ndarray, torch.Tensor]
    latent: Union[None, torch.Tensor]
This is required to ensure reasonable results with various model flavors compatible with the pipeline, such as those relying on very short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). default_processing_resolution (`int`, *optional*): The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in the model config. When the pipeline is called without explicitly setting `processing_resolution`, the default value is used. This is required to ensure reasonable results with various model flavors trained with varying optimal processing resolution values. """ model_cpu_offload_seq = "text_encoder->unet->vae" supported_prediction_types = ("normals",) def __init__( self, unet: UNet2DConditionModel, vae: AutoencoderKL, scheduler: Union[DDIMScheduler, LCMScheduler], text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, prediction_type: Optional[str] = None, use_full_z_range: Optional[bool] = True, default_denoising_steps: Optional[int] = None, default_processing_resolution: Optional[int] = None, ): super().__init__() if prediction_type not in self.supported_prediction_types: logger.warning( f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: " f"{self.supported_prediction_types}." 
def __init__(
    self,
    unet: UNet2DConditionModel,
    vae: AutoencoderKL,
    scheduler: Union[DDIMScheduler, LCMScheduler],
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    prediction_type: Optional[str] = None,
    use_full_z_range: Optional[bool] = True,
    default_denoising_steps: Optional[int] = None,
    default_processing_resolution: Optional[int] = None,
):
    """
    Assemble the normals pipeline from its sub-models and model-specific defaults.

    Args:
        unet, vae, scheduler, text_encoder, tokenizer:
            Components handed to `register_modules`.
        prediction_type (`str`, *optional*):
            Checked against `supported_prediction_types`; an unknown value only triggers a warning.
        use_full_z_range (`bool`, *optional*):
            Stored on the pipeline as `self.use_full_z_range`.
        default_denoising_steps, default_processing_resolution (`int`, *optional*):
            Model-specific defaults stored on the pipeline.
    """
    super().__init__()

    if prediction_type not in self.supported_prediction_types:
        logger.warning(
            f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: "
            f"{self.supported_prediction_types}."
        )

    self.register_modules(
        unet=unet,
        vae=vae,
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
    )
    # `prediction_type` is registered together with the other constructor options, the same way the Marigold
    # depth pipeline registers it; previously it was accepted and validated here but never registered.
    # NOTE(review): presumably this is what lets the value survive saving/reloading the pipeline config —
    # confirm against `register_to_config`.
    self.register_to_config(
        prediction_type=prediction_type,
        use_full_z_range=use_full_z_range,
        default_denoising_steps=default_denoising_steps,
        default_processing_resolution=default_processing_resolution,
    )

    # Each VAE block after the first halves the resolution once more.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    self.use_full_z_range = use_full_z_range
    self.default_denoising_steps = default_denoising_steps
    self.default_processing_resolution = default_processing_resolution

    # Not computed at construction time; starts out as `None`.
    self.empty_text_embedding = None

    self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
) if processing_resolution < 0: raise ValueError( "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " "downsampled processing." ) if processing_resolution % self.vae_scale_factor != 0: raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): raise ValueError( "`resample_method_input` takes string values compatible with PIL library: " "nearest, nearest-exact, bilinear, bicubic, area." ) if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): raise ValueError( "`resample_method_output` takes string values compatible with PIL library: " "nearest, nearest-exact, bilinear, bicubic, area." ) if batch_size < 1: raise ValueError("`batch_size` must be positive.") if output_type not in ["pt", "np"]: raise ValueError("`output_type` must be one of `pt` or `np`.") if latents is not None and generator is not None: raise ValueError("`latents` and `generator` cannot be used together.") if ensembling_kwargs is not None: if not isinstance(ensembling_kwargs, dict): raise ValueError("`ensembling_kwargs` must be a dictionary.") if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"): raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.") # image checks num_images = 0 W, H = None, None if not isinstance(image, list): image = [image] for i, img in enumerate(image): if isinstance(img, np.ndarray) or torch.is_tensor(img): if img.ndim not in (2, 3, 4): raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") H_i, W_i = img.shape[-2:] N_i = 1 if img.ndim == 4: N_i = img.shape[0] elif isinstance(img, Image.Image): W_i, H_i = img.size N_i = 1 else: raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") if W is None: W, H = W_i, H_i elif (W, H) != (W_i, H_i): raise 
def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
    """
    Builds a `tqdm` progress bar configured from `self._progress_bar_config`.

    Args:
        iterable (*optional*):
            Iterable to wrap. Takes precedence over `total` when both are given.
        total (*optional*):
            Total number of expected updates, used when no `iterable` is passed.
        desc (*optional*):
            Description used unless `self._progress_bar_config` already defines `"desc"`.
        leave (*optional*, defaults to `True`):
            Value used unless `self._progress_bar_config` already defines `"leave"`.

    Returns:
        The `tqdm` object created from the resolved configuration.

    Raises:
        ValueError: If `self._progress_bar_config` is not a `dict`, or if neither `iterable` nor `total` is set.
    """
    # Create the per-pipeline configuration on first use.
    try:
        stored_config = self._progress_bar_config
    except AttributeError:
        stored_config = self._progress_bar_config = {}

    if not isinstance(stored_config, dict):
        raise ValueError(
            f"`self._progress_bar_config` should be of type `dict`, but is {type(stored_config)}."
        )

    # Work on a copy so that the call-specific defaults never leak into the stored configuration.
    tqdm_kwargs = dict(**stored_config)
    tqdm_kwargs.setdefault("desc", desc)
    tqdm_kwargs.setdefault("leave", leave)

    if iterable is None and total is None:
        raise ValueError("Either `total` or `iterable` has to be defined.")
    if iterable is not None:
        return tqdm(iterable, **tqdm_kwargs)
    return tqdm(total=total, **tqdm_kwargs)
ensemble_size (`int`, defaults to `1`): Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for faster inference. processing_resolution (`int`, *optional*, defaults to `None`): Effective processing resolution. When set to `0`, matches the larger input image dimension. This produces crisper predictions, but may also lead to the overall loss of global context. The default value `None` resolves to the optimal value from the model config. match_input_resolution (`bool`, *optional*, defaults to `True`): When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer side of the output will equal to `processing_resolution`. resample_method_input (`str`, *optional*, defaults to `"bilinear"`): Resampling method used to resize input images to `processing_resolution`. The accepted values are: `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. resample_method_output (`str`, *optional*, defaults to `"bilinear"`): Resampling method used to resize output predictions to match the input resolution. The accepted values are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. batch_size (`int`, *optional*, defaults to `1`): Batch size; only matters when setting `ensemble_size` or passing a tensor of images. ensembling_kwargs (`dict`, *optional*, defaults to `None`) Extra dictionary with arguments for precise ensembling control. The following options are available: - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in every pixel location, can be either `"closest"` or `"mean"`. latents (`torch.Tensor`, *optional*, defaults to `None`): Latent noise tensors to replace the random initialization. These can be taken from the previous function call's output. generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`): Random number generator object to ensure reproducibility. 
output_type (`str`, *optional*, defaults to `"np"`): Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted values are: `"np"` (numpy array) or `"pt"` (torch tensor). output_uncertainty (`bool`, *optional*, defaults to `False`): When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that the `ensemble_size` argument is set to a value above 2. output_latent (`bool`, *optional*, defaults to `False`): When enabled, the output's `latent` field contains the latent codes corresponding to the predictions within the ensemble. These codes can be saved, modified, and used for subsequent calls with the `latents` argument. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. Examples: Returns: [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a `tuple` is returned where the first element is the prediction, the second element is the uncertainty (or `None`), and the third is the latent (or `None`). """ # 0. Resolving variables. device = self._execution_device dtype = self.dtype # Model-specific optimal default values leading to fast and reasonable results. if num_inference_steps is None: num_inference_steps = self.default_denoising_steps if processing_resolution is None: processing_resolution = self.default_processing_resolution # 1. Check inputs. num_images = self.check_inputs( image, num_inference_steps, ensemble_size, processing_resolution, resample_method_input, resample_method_output, batch_size, ensembling_kwargs, latents, generator, output_type, output_uncertainty, ) # 2. Prepare empty text conditioning. # Model invocation: self.tokenizer, self.text_encoder. 
if self.empty_text_embedding is None: prompt = "" text_inputs = self.tokenizer( prompt, padding="do_not_pad", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(device) self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of # operation and leads to the most reasonable results. Using the native image resolution or any other processing # resolution can lead to loss of either fine details or global context in the output predictions. image, padding, original_resolution = self.image_processor.preprocess( image, processing_resolution, resample_method_input, device, dtype ) # [N,3,PPH,PPW] # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. # Latents of each such predictions across all input images and all ensemble members are represented in the # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline # code. 
For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. # Model invocation: self.vae.encoder. image_latent, pred_latent = self.prepare_latents( image, latents, generator, ensemble_size, batch_size ) # [N*E,4,h,w], [N*E,4,h,w] del image batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat( batch_size, 1, 1 ) # [B,1024,2] # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`. # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and # outputs noise for the predicted modality's latent space. The number of denoising diffusion steps is defined by # `num_inference_steps`. It is either set directly, or resolves to the optimal value specific to the loaded # model. # Model invocation: self.unet. pred_latents = [] for i in self.progress_bar( range(0, num_images * ensemble_size, batch_size), leave=True, desc="Marigold predictions..." 
): batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w] batch_pred_latent = pred_latent[i : i + batch_size] # [B,4,h,w] effective_batch_size = batch_image_latent.shape[0] text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024] self.scheduler.set_timesteps(num_inference_steps, device=device) for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."): batch_latent = torch.cat([batch_image_latent, batch_pred_latent], dim=1) # [B,8,h,w] noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,4,h,w] batch_pred_latent = self.scheduler.step( noise, t, batch_pred_latent, generator=generator ).prev_sample # [B,4,h,w] pred_latents.append(batch_pred_latent) pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w] del ( pred_latents, image_latent, batch_empty_text_embedding, batch_image_latent, batch_pred_latent, text, batch_latent, noise, ) # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`, # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`. # Model invocation: self.vae.decoder. prediction = torch.cat( [ self.decode_prediction(pred_latent[i : i + batch_size]) for i in range(0, pred_latent.shape[0], batch_size) ], dim=0, ) # [N*E,3,PPH,PPW] if not output_latent: pred_latent = None # 7. Remove padding. The output shape is (PH, PW). prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW] # 8. Ensemble and compute uncertainty (when `output_uncertainty` is set). This code treats each of the `N` # groups of `E` ensemble predictions independently. For each group it computes an ensembled prediction of shape # `(PH, PW)` and an optional uncertainty map of the same dimensions. After computing this pair of outputs for # each group independently, it stacks them respectively into batches of `N` almost final predictions and # uncertainty maps. 
uncertainty = None if ensemble_size > 1: prediction = prediction.reshape(num_images, ensemble_size, *prediction.shape[1:]) # [N,E,3,PH,PW] prediction = [ self.ensemble_normals(prediction[i], output_uncertainty, **(ensembling_kwargs or {})) for i in range(num_images) ] # [ [[1,3,PH,PW], [1,1,PH,PW]], ... ] prediction, uncertainty = zip(*prediction) # [[1,3,PH,PW], ... ], [[1,1,PH,PW], ... ] prediction = torch.cat(prediction, dim=0) # [N,3,PH,PW] if output_uncertainty: uncertainty = torch.cat(uncertainty, dim=0) # [N,1,PH,PW] else: uncertainty = None # 9. If `match_input_resolution` is set, the output prediction and the uncertainty are upsampled to match the # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled. # After upsampling, the native resolution normal maps are renormalized to unit length to reduce the artifacts. # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by # setting the `resample_method_output` parameter (e.g., to `"nearest"`). if match_input_resolution: prediction = self.image_processor.resize_antialias( prediction, original_resolution, resample_method_output, is_aa=False ) # [N,3,H,W] prediction = self.normalize_normals(prediction) # [N,3,H,W] if uncertainty is not None and output_uncertainty: uncertainty = self.image_processor.resize_antialias( uncertainty, original_resolution, resample_method_output, is_aa=False ) # [N,1,H,W] # 10. Prepare the final outputs. if output_type == "np": prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3] if uncertainty is not None and output_uncertainty: uncertainty = self.image_processor.pt_to_numpy(uncertainty) # [N,H,W,1] # 11. 
def prepare_latents(
    self,
    image: torch.Tensor,
    latents: Optional[torch.Tensor],
    generator: Optional[torch.Generator],
    ensemble_size: int,
    batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encodes the input images with the VAE and provides the initial prediction latents.

    Args:
        image (`torch.Tensor`):
            Batch of preprocessed images; it is encoded in slices of `batch_size` along the first dimension.
        latents (`torch.Tensor`, *optional*):
            Initial prediction latents. When given, they are returned unchanged.
        generator (`torch.Generator`, *optional*):
            Passed to `randn_tensor` when `latents` is `None`.
        ensemble_size (`int`):
            How many times each encoded image is repeated along the first dimension.
        batch_size (`int`):
            Number of images encoded per VAE call.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: The scaled and repeated image latents, and the prediction latents
        (either `latents` or freshly sampled noise of the same shape, device and dtype as the image latents).

    Raises:
        AttributeError: If the VAE encoder output exposes neither `latent_dist` nor `latents`.
    """
    encoded_chunks = []
    for start in range(0, image.shape[0], batch_size):
        encoder_output = self.vae.encode(image[start : start + batch_size])
        # Prefer the mode of the latent distribution; fall back to plain latents.
        if hasattr(encoder_output, "latent_dist"):
            chunk = encoder_output.latent_dist.mode()
        elif hasattr(encoder_output, "latents"):
            chunk = encoder_output.latents
        else:
            raise AttributeError("Could not access latents of provided encoder_output")
        encoded_chunks.append(chunk)

    image_latent = torch.cat(encoded_chunks, dim=0) * self.vae.config.scaling_factor  # [N,4,h,w]
    image_latent = image_latent.repeat_interleave(ensemble_size, dim=0)  # [N*E,4,h,w]

    if latents is not None:
        return image_latent, latents

    noise = randn_tensor(
        image_latent.shape,
        generator=generator,
        device=image_latent.device,
        dtype=image_latent.dtype,
    )  # [N*E,4,h,w]
    return image_latent, noise
@staticmethod
def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Rescales every 3-vector stored along the channel dimension to unit length.

    Args:
        normals (`torch.Tensor`):
            Tensor of shape `[B,3,H,W]`. It is not modified by this function.
        eps (`float`, *optional*, defaults to `1e-6`):
            Lower bound applied to the vector length before dividing, so that zero-length vectors do not produce
            a division by zero.

    Returns:
        `torch.Tensor`: A new tensor of shape `[B,3,H,W]` holding the normalized vectors.

    Raises:
        ValueError: If `normals` is not a 4D tensor with exactly 3 channels.
    """
    if normals.dim() != 4 or normals.shape[1] != 3:
        raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")

    norm = torch.norm(normals, dim=1, keepdim=True)
    # Divide out-of-place: an in-place `/=` would silently overwrite the tensor the caller passed in.
    return normals / norm.clamp(min=eps)
""" if normals.dim() != 4 or normals.shape[1] != 3: raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") if reduction not in ("closest", "mean"): raise ValueError(f"Unrecognized reduction method: {reduction}.") mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W] mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W] sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W] sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16 uncertainty = None if output_uncertainty: uncertainty = sim_cos.arccos() # [E,1,H,W] uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W] if reduction == "mean": return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W] closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W] closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W] closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W] return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W] ================================================ FILE: src/diffusers/pipelines/musicldm/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_torch_available, is_transformers_available, is_transformers_version, ) _dummy_objects = {} _import_structure = {} try: if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_musicldm"] = ["MusicLDMPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available() and 
is_transformers_version(">=", "4.27.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_musicldm import MusicLDMPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/musicldm/pipeline_musicldm.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch from transformers import ( ClapFeatureExtractor, ClapModel, ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan, ) from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, is_accelerate_version, is_librosa_available, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin if is_librosa_available(): import librosa logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import MusicLDMPipeline >>> import torch >>> import scipy >>> repo_id = "ucsd-reach/musicldm" >>> pipe = MusicLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") >>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs" >>> audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0] >>> # save the audio sample as a .wav file >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio) ``` """ class MusicLDMPipeline(DiffusionPipeline, StableDiffusionMixin): r""" Pipeline for text-to-audio generation using MusicLDM. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.ClapModel`]): Frozen text-audio embedding model (`ClapTextModel`), specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: Union[ClapTextModelWithProjection, ClapModel],
    tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
    feature_extractor: Optional[ClapFeatureExtractor],
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    vocoder: SpeechT5HifiGan,
):
    # Registers every sub-model on the pipeline and derives the VAE scale factor from the VAE config.
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        unet=unet,
        scheduler=scheduler,
        vocoder=vocoder,
    )
    # Ratio between spectrogram and latent resolution; `check_inputs` and `prepare_latents` divide by it.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLAP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder.get_text_features( text_input_ids.to(device), attention_mask=attention_mask.to(device), ) prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.text_model.dtype, device=device) ( bs_embed, seq_len, ) = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt) prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and 
negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) uncond_input_ids = uncond_input.input_ids.to(device) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder.get_text_features( uncond_input_ids, attention_mask=attention_mask, ) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.text_model.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len) # For classifier free guidance, we need to do two forward passes. 
def mel_spectrogram_to_waveform(self, mel_spectrogram):
    """
    Runs the vocoder on a mel-spectrogram and returns the result as a float32 CPU tensor.

    Args:
        mel_spectrogram (`torch.Tensor`):
            Spectrogram passed to `self.vocoder`. A 4D input has its dimension 1 squeezed first.

    Returns:
        `torch.Tensor`: The vocoder output moved to the CPU and cast to `float32`.
    """
    vocoder_input = mel_spectrogram.squeeze(1) if mel_spectrogram.dim() == 4 else mel_spectrogram
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return self.vocoder(vocoder_input).cpu().float()
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collects the optional keyword arguments accepted by `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are forwarded only when the
    signature declares a parameter of that name. eta (η) is only used with the DDIMScheduler and corresponds to
    η in the DDIM paper (https://arxiv.org/abs/2010.02502); it should be between [0, 1].

    Args:
        generator: Value stored under `"generator"` when the scheduler accepts it.
        eta: Value stored under `"eta"` when the scheduler accepts it.

    Returns:
        `dict`: Keyword arguments, with `"eta"` placed before `"generator"` when both are present.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_parameters}
def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
    """
    Return the initial latents for the denoising loop, scaled by `self.scheduler.init_noise_sigma`.

    Args:
        batch_size: Effective batch size (first latent dimension).
        num_channels_latents: Number of latent channels (second latent dimension).
        height: Spectrogram height; divided by `self.vae_scale_factor` for the third latent dimension.
        dtype: dtype used when new noise is sampled.
        device: Device for the sampled noise, or the device supplied latents are moved to.
        generator: Generator, or list of generators, forwarded to `randn_tensor`.
        latents: Optional pre-made latents. When given they are only moved to `device` (no resampling).

    Returns:
        The latents multiplied by the scheduler's `init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    downscale = self.vae_scale_factor
    latent_shape = (
        batch_size,
        num_channels_latents,
        int(height) // downscale,
        int(self.vocoder.config.model_in_dim) // downscale,
    )

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        initial_latents = latents.to(device)
    else:
        initial_latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial_latents * self.scheduler.init_noise_sigma
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    audio_length_in_s: Optional[float] = None,
    num_inference_steps: int = 200,
    guidance_scale: float = 2.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_waveforms_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: Optional[int] = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    output_type: Optional[str] = "np",
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
        audio_length_in_s (`float`, *optional*, defaults to 10.24):
            The length of the generated audio sample in seconds. When `None`, it is computed as
            `unet.config.sample_size * vae_scale_factor * vocoder_upsample_factor`.
        num_inference_steps (`int`, *optional*, defaults to 200):
            The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 2.0):
            A higher guidance scale value encourages the model to generate audio that is closely linked to the text
            `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
            The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, the text encoding
            model is a joint text-audio model ([`~transformers.ClapModel`]), and the tokenizer is a
            `[~transformers.ClapProcessor]`, then automatic scoring will be performed between the generated outputs
            and the input text. This scoring ranks the generated waveforms based on their cosine similarity to text
            input in the joint text-audio embedding space. Scoring is only attempted when `prompt` is passed (it is
            skipped when only `prompt_embeds` are given).
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that calls every `callback_steps` steps during inference. The function is called with the
            following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. If not specified, the callback is called at
            every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
            `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
            model (LDM) output; in that case an [`~pipelines.AudioPipelineOutput`] holding the latents is returned
            regardless of `return_dict`.

    Examples:

    Returns:
        [`~pipelines.AudioPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is
            returned where the first element is the generated audio (`np.ndarray` for `output_type="np"`,
            otherwise a `torch.Tensor`).
    """
    # 0. Convert audio input length from seconds to spectrogram height
    vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate

    if audio_length_in_s is None:
        audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor

    height = int(audio_length_in_s / vocoder_upsample_factor)

    # number of waveform samples to keep once the (possibly padded) spectrogram has been vocoded
    original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
    if height % self.vae_scale_factor != 0:
        # round the height up to the next multiple of the VAE scale factor
        height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
        logger.info(
            f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
            f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
            f"denoising process."
        )

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        audio_length_in_s,
        vocoder_upsample_factor,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_waveforms_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_waveforms_per_prompt,
        num_channels_latents,
        height,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            # the prompt embeddings are passed as `class_labels`; no `encoder_hidden_states` are supplied
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=None,
                class_labels=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    self.maybe_free_model_hooks()

    # 8. Post-processing
    if not output_type == "latent":
        latents = 1 / self.vae.config.scaling_factor * latents
        mel_spectrogram = self.vae.decode(latents).sample
    else:
        # NOTE: this early return does not consult `return_dict`
        return AudioPipelineOutput(audios=latents)

    audio = self.mel_spectrogram_to_waveform(mel_spectrogram)

    # trim the samples that were only generated because the height was rounded up in step 0
    audio = audio[:, :original_waveform_length]

    # 9. Automatic scoring
    if num_waveforms_per_prompt > 1 and prompt is not None:
        audio = self.score_waveforms(
            text=prompt,
            audio=audio,
            num_waveforms_per_prompt=num_waveforms_per_prompt,
            device=device,
            dtype=prompt_embeds.dtype,
        )

    if output_type == "np":
        audio = audio.numpy()

    if not return_dict:
        return (audio,)

    return AudioPipelineOutput(audios=audio)
class OnnxRuntimeModel:
    """
    Wrapper around an ONNX Runtime inference session with helpers to load it and to copy its files elsewhere.

    The session is stored in `self.model`. `model_save_dir` and `latest_model_name` record the directory and file
    name the model was loaded from so that `save_pretrained` can copy the file to a new location.
    """

    def __init__(self, model=None, **kwargs):
        """
        Arguments:
            model: The object whose `run` method is invoked by `__call__` (an `onnxruntime.InferenceSession` when
                created through `load_model`).
            model_save_dir (`str` or `Path`, *optional*, keyword only): Directory containing the model file.
            latest_model_name (`str`, *optional*, keyword only): File name of the model inside `model_save_dir`,
                defaults to `ONNX_WEIGHTS_NAME`.
        """
        logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
        self.model = model
        self.model_save_dir = kwargs.get("model_save_dir", None)
        self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)

    def __call__(self, **kwargs):
        """Run the session; every keyword argument is converted with `np.array` and fed as a named input."""
        inputs = {k: np.array(v) for k, v in kwargs.items()}
        return self.model.run(None, inputs)

    @staticmethod
    def load_model(path: Union[str, Path], provider=None, sess_options=None):
        """
        Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`

        Arguments:
            path (`str` or `Path`):
                Path of the ONNX model file to load.
            provider(`str`, *optional*):
                Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
            sess_options (`onnxruntime.SessionOptions`, *optional*):
                Session options forwarded to `onnxruntime.InferenceSession`.
        """
        if provider is None:
            logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
            provider = "CPUExecutionProvider"

        return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)

    def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
        """
        Copy the model file (and its external weights file, if present) into `save_directory`, so that it can be
        re-loaded using the [`~OnnxRuntimeModel.from_pretrained`] class method. It will always save the
        latest_model_name.

        Arguments:
            save_directory (`str` or `Path`):
                Directory where to save the model file.
            file_name(`str`, *optional*):
                Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
                model with a different name.

        Raises:
            ValueError: If the instance has no `model_save_dir`, i.e. there is no file to copy from.
        """
        if self.model_save_dir is None:
            raise ValueError(
                "Cannot save the model because `model_save_dir` is not set. Load the model with `from_pretrained` "
                "or pass `model_save_dir` when creating the `OnnxRuntimeModel`."
            )

        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME

        # `model_save_dir` may have been supplied as a plain string; normalise before using the `Path` API
        source_dir = Path(self.model_save_dir)
        target_dir = Path(save_directory)

        src_path = source_dir.joinpath(self.latest_model_name)
        dst_path = target_dir.joinpath(model_file_name)
        try:
            shutil.copyfile(src_path, dst_path)
        except shutil.SameFileError:
            # saving into the directory the model was loaded from: nothing to copy
            pass

        # copy external weights (for models >2GB)
        src_path = source_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
        if src_path.exists():
            dst_path = target_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
            try:
                shutil.copyfile(src_path, dst_path)
            except shutil.SameFileError:
                pass

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        **kwargs,
    ):
        """
        Save a model to a directory, so that it can be re-loaded using the [`~OnnxRuntimeModel.from_pretrained`]
        class method.

        If `save_directory` points to an existing file, an error is logged and nothing is saved.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            kwargs:
                Forwarded to `_save_pretrained` (e.g. `file_name`).
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        os.makedirs(save_directory, exist_ok=True)

        # saving model weights/files
        self._save_pretrained(save_directory, **kwargs)

    @classmethod
    @validate_hf_hub_args
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        token: Optional[Union[bool, str, None]] = None,
        revision: Optional[Union[str, None]] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["ort.SessionOptions"] = None,
        **kwargs,
    ):
        """
        Load a model from a directory or the HF Hub.

        Arguments:
            model_id (`str` or `Path`):
                Local directory to load from, or the id of a repository on the Hub
            token (`str` or `bool`):
                Is needed to load models from a private or gated repository
            revision (`str`):
                Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
            cache_dir (`Union[str, Path]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            file_name(`str`):
                Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
                different model files from the same repository or directory.
            provider(`str`):
                The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
            sess_options (`onnxruntime.SessionOptions`, *optional*):
                Session options forwarded to `load_model`.
            kwargs (`Dict`, *optional*):
                kwargs will be passed to the model during initialization
        """
        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
        # load model from local directory
        if os.path.isdir(model_id):
            model = OnnxRuntimeModel.load_model(
                Path(model_id, model_file_name).as_posix(), provider=provider, sess_options=sess_options
            )
            kwargs["model_save_dir"] = Path(model_id)
        # load model from hub
        else:
            # download model
            model_cache_path = hf_hub_download(
                repo_id=model_id,
                filename=model_file_name,
                token=token,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
            )
            # remember where the file landed in the cache so `save_pretrained` can copy it later
            kwargs["model_save_dir"] = Path(model_cache_path).parent
            kwargs["latest_model_name"] = Path(model_cache_path).name
            model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
        return cls(model=model, **kwargs)

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        force_download: bool = True,
        token: Optional[str] = None,
        cache_dir: Optional[str] = None,
        **model_kwargs,
    ):
        """
        Load a model from a local directory or from the HF Hub.

        Arguments:
            model_id (`str` or `Path`):
                Local directory, or a Hub repository id. A repository id may carry a revision suffix in the form
                `"repo_id@revision"`. The suffix is not interpreted for existing local directories.
            force_download (`bool`, *optional*, defaults to `True`):
                Whether to (re-)download the model file even if a cached copy exists.
            token (`str`, *optional*):
                Token used for private or gated repositories.
            cache_dir (`str`, *optional*):
                Directory in which downloaded files are cached.
            model_kwargs:
                Forwarded to `_from_pretrained` (e.g. `file_name`, `provider`, `sess_options`, `revision`). An
                explicit `revision` takes precedence over a `@revision` suffix in `model_id`.
        """
        # Pull an explicit `revision` out of the pass-through kwargs; otherwise it would clash with the
        # `revision=` keyword given to `_from_pretrained` below and raise a TypeError.
        revision = model_kwargs.pop("revision", None)

        # Only parse "repo@revision" for non-local ids: a local path may legitimately contain "@". Splitting the
        # string form also keeps `Path` inputs working (`Path` has no `split` method).
        if not os.path.isdir(model_id):
            id_parts = str(model_id).split("@")
            if len(id_parts) == 2:
                model_id, suffix_revision = id_parts
                if revision is None:
                    revision = suffix_revision

        return cls._from_pretrained(
            model_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            token=token,
            **model_kwargs,
        )
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"] _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"] _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"] _import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"] _import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"] _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"] _import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"] _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"] _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] _import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"] _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline from .pipeline_pag_kolors import KolorsPAGPipeline from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline from .pipeline_pag_sd import StableDiffusionPAGPipeline from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline else: import sys sys.modules[__name__] = _LazyModule( 
class PAGMixin:
    r"""Mixin class for [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377v1)."""

    def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
        r"""
        Set the attention processor for the PAG layers.

        Args:
            pag_applied_layers (`List[str]`):
                Layer identifiers; each one is used as a regex that is searched for in the module names of the
                denoiser (`self.unet` if present, otherwise `self.transformer`).
            do_classifier_free_guidance (`bool`):
                Selects the first (CFG enabled) or second (CFG disabled) of the configured PAG processors.

        Raises:
            ValueError: If no PAG attention processors have been configured, or if a layer identifier matches no
                self-attention module.
        """
        # `getattr` so that a pipeline on which `set_pag_applied_layers` was never called gets the explanatory
        # ValueError below rather than an AttributeError.
        pag_attn_processors = getattr(self, "_pag_attn_processors", None)
        if pag_attn_processors is None:
            raise ValueError(
                "No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters."
            )

        pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1]

        if hasattr(self, "unet"):
            model: nn.Module = self.unet
        else:
            model: nn.Module = self.transformer

        def is_self_attn(module: nn.Module) -> bool:
            r"""
            Check if the module is a self-attention module: an `Attention` instance that is not cross-attention.
            """
            return isinstance(module, Attention) and not module.is_cross_attention

        def is_fake_integral_match(layer_id, name):
            r"""
            Return `True` when both identifiers end in a numeric component and those components are equal.
            """
            layer_id = layer_id.split(".")[-1]
            name = name.split(".")[-1]
            return layer_id.isnumeric() and name.isnumeric() and layer_id == name

        for layer_id in pag_applied_layers:
            # for each PAG layer input, we find corresponding self-attention layers in the unet model
            target_modules = []

            for name, module in model.named_modules():
                # Identify the following simple cases:
                #   (1) Self Attention layer existing
                #   (2) Whether the module name matches pag layer id even partially
                #   (3) Make sure it's not a fake integral match if the layer_id ends with a number
                #       For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1"
                if (
                    is_self_attn(module)
                    and re.search(layer_id, name) is not None
                    and not is_fake_integral_match(layer_id, name)
                ):
                    logger.debug(f"Applying PAG to layer: {name}")
                    target_modules.append(module)

            if not target_modules:
                raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}")

            for module in target_modules:
                module.processor = pag_attn_proc

    def _get_pag_scale(self, t):
        r"""
        Get the scale factor for the perturbed attention guidance at timestep `t`.

        With adaptive scaling enabled the scale is `pag_scale - pag_adaptive_scale * (1000 - t)`, clamped at 0;
        otherwise `pag_scale` is returned unchanged.
        """

        if self.do_pag_adaptive_scaling:
            # NOTE(review): 1000 is presumably the number of training timesteps; confirm against the scheduler
            signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
            if signal_scale < 0:
                signal_scale = 0
            return signal_scale
        else:
            return self.pag_scale

    def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
        r"""
        Apply perturbed attention guidance to the noise prediction.

        Args:
            noise_pred (torch.Tensor): The noise prediction tensor. It is split along dim 0 into three chunks
                (unconditional, text, perturbed) with CFG, or two chunks (text, perturbed) without.
            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
            guidance_scale (float): The scale factor for the guidance term.
            t (int): The current time step.

        Returns:
            torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
        """
        pag_scale = self._get_pag_scale(t)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
            noise_pred = (
                noise_pred_uncond
                + guidance_scale * (noise_pred_text - noise_pred_uncond)
                + pag_scale * (noise_pred_text - noise_pred_perturb)
            )
        else:
            noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
        return noise_pred

    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
        """
        Prepares the perturbed attention guidance for the PAG model.

        The conditional input is duplicated along dim 0 (one copy for the regular pass, one for the perturbed
        pass); with classifier-free guidance the unconditional input is prepended.

        Args:
            cond (torch.Tensor): The conditional input tensor.
            uncond (torch.Tensor): The unconditional input tensor. Only used when
                `do_classifier_free_guidance` is true.
            do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.

        Returns:
            torch.Tensor: The prepared perturbed attention guidance tensor.
        """

        cond = torch.cat([cond] * 2, dim=0)

        if do_classifier_free_guidance:
            cond = torch.cat([uncond, cond], dim=0)
        return cond

    def set_pag_applied_layers(
        self,
        pag_applied_layers: Union[str, List[str]],
        pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
            PAGCFGIdentitySelfAttnProcessor2_0(),
            PAGIdentitySelfAttnProcessor2_0(),
        ),
    ):
        r"""
        Set the self-attention layers to apply PAG. Raise ValueError if the input is invalid.

        Args:
            pag_applied_layers (`str` or `List[str]`):
                One or more strings identifying the layer names, or a simple regex for matching multiple layers, where
                PAG is to be applied. A few ways of expected usage are as follows:
                  - Single layers specified as - "blocks.{layer_index}"
                  - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...]
                  - Multiple layers as a block name - "mid"
                  - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})"
            pag_attn_processors:
                (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(),
                PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention
                processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second
                attention processor is for PAG with CFG disabled (conditional only).

        Raises:
            ValueError: If `pag_attn_processors` is not `None` and not a tuple of length two, or if any layer
                identifier is not a string.
        """

        # make sure the attribute exists even if validation below fails
        if not hasattr(self, "_pag_attn_processors"):
            self._pag_attn_processors = None

        if not isinstance(pag_applied_layers, list):
            pag_applied_layers = [pag_applied_layers]
        if pag_attn_processors is not None:
            if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2:
                raise ValueError("Expected a tuple of two attention processors")

        for layer in pag_applied_layers:
            if not isinstance(layer, str):
                raise ValueError(f"Expected either a string or a list of string but got type {type(layer)}")

        self.pag_applied_layers = pag_applied_layers
        self._pag_attn_processors = pag_attn_processors

    @property
    def pag_scale(self) -> float:
        r"""Get the scale factor for the perturbed attention guidance."""
        return self._pag_scale

    @property
    def pag_adaptive_scale(self) -> float:
        r"""Get the adaptive scale factor for the perturbed attention guidance."""
        return self._pag_adaptive_scale

    @property
    def do_pag_adaptive_scaling(self) -> bool:
        r"""Check if the adaptive scaling is enabled for the perturbed attention guidance."""
        return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def do_perturbed_attention_guidance(self) -> bool:
        r"""Check if the perturbed attention guidance is enabled."""
        return self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def pag_attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model
            with the key as the name of the layer. Empty when no PAG processors have been configured.

        Raises:
            ValueError: If the pipeline has neither a `unet` nor a `transformer` attribute.
        """

        # tolerate pipelines on which `set_pag_applied_layers` has not been called yet
        configured_processors = getattr(self, "_pag_attn_processors", None)
        if configured_processors is None:
            return {}

        valid_attn_processors = {x.__class__ for x in configured_processors}

        # We could have iterated through the self.components.items() and checked if a component is
        # `ModelMixin` subclassed but that can include a VAE too.
        if hasattr(self, "unet"):
            denoiser_module = self.unet
        elif hasattr(self, "transformer"):
            denoiser_module = self.transformer
        else:
            raise ValueError("No denoiser module found.")

        return {
            name: proc
            for name, proc in denoiser_module.attn_processors.items()
            if proc.__class__ in valid_attn_processors
        }
import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..controlnet.multicontrolnet import MultiControlNetModel from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from .pag_utils import PAGMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install opencv-python transformers accelerate >>> from diffusers import AutoPipelineForText2Image, ControlNetModel, UniPCMultistepScheduler >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> import cv2 >>> from PIL import Image >>> # download an image >>> image = load_image( ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" ... 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` via its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or (when neither is
    given) a plain `num_inference_steps` count. Extra keyword arguments are forwarded to
    `scheduler.set_timesteps` unchanged.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is passed.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration, and the number of inference
        steps (the length of the schedule for custom schedules, otherwise `num_inference_steps` as passed).

    Raises:
        `ValueError`: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does
        not accept the custom argument that was supplied.
    """
    has_custom_timesteps = timesteps is not None
    has_custom_sigmas = sigmas is not None

    if has_custom_timesteps and has_custom_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler derive its own spacing from the step count.
    if not has_custom_timesteps and not has_custom_sigmas:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: the scheduler must explicitly accept the corresponding keyword.
    accepted_parameters = set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if has_custom_timesteps:
        if "timesteps" not in accepted_parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    image_encoder: CLIPVisionModelWithProjection = None,
    requires_safety_checker: bool = True,
    pag_applied_layers: Union[str, List[str]] = "mid",
):
    """
    Register the pipeline components and set up image processors and PAG layers.

    Args:
        vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, image_encoder:
            Components registered on the pipeline through `register_modules`.
        controlnet:
            A single ControlNet, or a list/tuple of ControlNets which is wrapped in a `MultiControlNetModel`
            before registration.
        requires_safety_checker (`bool`, defaults to `True`):
            Stored in the pipeline config. When `True` and `safety_checker` is `None`, a warning is logged.
        pag_applied_layers (`str` or `List[str]`, defaults to `"mid"`):
            Forwarded to `set_pag_applied_layers`.

    Raises:
        `ValueError`: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message interpolates the pipeline class, so the literal must be an f-string;
        # without the prefix the text "{self.__class__}" was emitted verbatim.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Several ControlNets are combined behind a single MultiControlNetModel.
    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    # Scale factor derived from the number of VAE block output channel entries.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    # The control image processor is created with normalization disabled.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)

    self.set_pag_applied_layers(pag_applied_layers)
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the 
encoder layers. Then index into # the tuple to access the hidden states from the desired layer. prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode `image` with `self.image_encoder` and build the matching unconditional encoding.

    Non-tensor inputs are first converted to pixel values with `self.feature_extractor`. The input is moved to
    `device` and cast to the dtype of the encoder's first parameter.

    Args:
        image: A `torch.Tensor`, or any input accepted by `self.feature_extractor`.
        device: Device the pixel values are moved to before encoding.
        num_images_per_prompt (`int`): Each encoding is repeated this many times along dim 0.
        output_hidden_states (`bool`, *optional*): When truthy, the second-to-last hidden state is returned
            instead of the pooled `image_embeds`.

    Returns:
        A `(conditional, unconditional)` tuple. With hidden states, the unconditional part is the encoder
        output for an all-zero image; otherwise it is a zero tensor shaped like the conditional embeddings.
    """
    encoder = self.image_encoder
    encoder_dtype = next(encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        pooled = encoder(image).image_embeds
        pooled = pooled.repeat_interleave(num_images_per_prompt, dim=0)
        return pooled, torch.zeros_like(pooled)

    def penultimate_hidden_state(pixels):
        # Second-to-last hidden state, repeated once per requested image.
        states = encoder(pixels, output_hidden_states=True).hidden_states[-2]
        return states.repeat_interleave(num_images_per_prompt, dim=0)

    cond_states = penultimate_hidden_state(image)
    uncond_states = penultimate_hidden_state(torch.zeros_like(image))
    return cond_states, uncond_states
def run_safety_checker(self, image, device, dtype):
    """
    Run `self.safety_checker` on `image`, if a safety checker is configured.

    Args:
        image: Generated images, either a tensor or an array-like accepted by
            `self.image_processor.numpy_to_pil`.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the extracted pixel values are cast to before being passed to the checker.

    Returns:
        `(image, has_nsfw_concept)`. When no safety checker is configured, `image` is returned unchanged and
        `has_nsfw_concept` is `None`; otherwise both values come from the safety checker.
    """
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images; how they are produced depends on the input kind.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    checker_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    clip_input = checker_features.pixel_values.to(dtype)
    image, has_nsfw_concept = self.safety_checker(images=image, clip_input=clip_input)
    return image, has_nsfw_concept
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Scheduler `step` signatures differ, so `eta` and `generator` are only included when the scheduler's
    `step` method declares a parameter of that name. `eta` corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Args:
        generator: Value passed as `generator` when supported.
        eta: Value passed as `eta` when supported.

    Returns:
        `dict`: Contains `"eta"` and/or `"generator"`, in that order, for the names the scheduler accepts.
    """
    step_parameters = set(inspect.signature(self.scheduler.step).parameters.keys())
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.controlnet, torch._dynamo.eval_frame.OptimizedModule ) if ( isinstance(self.controlnet, ControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel) ): self.check_image(image, prompt, prompt_embeds) elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel) ): if not isinstance(image, list): raise TypeError("For multiple controlnets: `image` must be type `list`") # When `image` is a nested list: # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): transposed_image = [list(t) for t in zip(*image)] if len(transposed_image) != len(self.controlnet.nets): raise ValueError( f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." 
) for image_ in transposed_image: self.check_image(image_, prompt, prompt_embeds) elif len(image) != len(self.controlnet.nets): raise ValueError( f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." ) else: for image_ in image: self.check_image(image_, prompt, prompt_embeds) else: assert False # Check `controlnet_conditioning_scale` if ( isinstance(self.controlnet, ControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel) ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel) ): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): raise ValueError( "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " "The conditioning scale must be fixed across the batch." ) elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( self.controlnet.nets ): raise ValueError( "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" " the same length as the number of controlnets" ) else: assert False if not isinstance(control_guidance_start, (tuple, list)): control_guidance_start = [control_guidance_start] if not isinstance(control_guidance_end, (tuple, list)): control_guidance_end = [control_guidance_end] if len(control_guidance_start) != len(control_guidance_end): raise ValueError( f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." 
) if isinstance(self.controlnet, MultiControlNetModel): if len(control_guidance_start) != len(self.controlnet.nets): raise ValueError( f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." ) for start, end in zip(control_guidance_start, control_guidance_end): if start >= end: raise ValueError( f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." ) if start < 0.0: raise ValueError(f"control guidance start: {start} can't be smaller than 0.") if end > 1.0: raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." ) if ip_adapter_image_embeds is not None: if not isinstance(ip_adapter_image_embeds, list): raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) image_is_np = isinstance(image, np.ndarray) image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) if ( not image_is_pil and not image_is_tensor and not image_is_np and not 
def prepare_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """
    Preprocess the ControlNet conditioning image and expand it to the effective batch size.

    Args:
        image: Conditioning image(s) accepted by `self.control_image_processor.preprocess`.
        width, height: Target size forwarded to the preprocessor.
        batch_size (`int`): Repeat count used when the preprocessed batch holds a single image.
        num_images_per_prompt (`int`): Repeat count used when the preprocessed batch holds several images.
        device, dtype: Device and dtype of the returned tensor.
        do_classifier_free_guidance (`bool`): When `True` (and `guess_mode` is `False`), the batch is
            concatenated with itself.
        guess_mode (`bool`): Disables the duplication above.

    Returns:
        `torch.Tensor`: The prepared conditioning batch.
    """
    control = self.control_image_processor.preprocess(image, height=height, width=width)
    control = control.to(dtype=torch.float32)

    # One image is repeated `batch_size` times; otherwise the image batch is taken to match the
    # prompt batch and each image is repeated once per generated image.
    repeats = batch_size if control.shape[0] == 1 else num_images_per_prompt
    control = control.repeat_interleave(repeats, dim=0).to(device=device, dtype=dtype)

    if do_classifier_free_guidance and not guess_mode:
        control = torch.cat([control, control])

    return control
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Compute sinusoidal embeddings of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scales to embed.
        embedding_dim (`int`, *optional*, defaults to 512):
            Size of each embedding vector. For odd sizes the last entry is zero padding.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embeddings of shape `(len(w), embedding_dim)`; the sine half comes first, followed by
        the cosine half.
    """
    assert len(w.shape) == 1
    scaled_w = w * 1000.0

    half_dim = embedding_dim // 2
    # Frequencies decay geometrically from 1 down to 1/10000 across the half dimension.
    log_step = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    frequencies = torch.exp(torch.arange(half_dim, dtype=dtype) * -log_step)

    angles = scaled_w.to(dtype).unsqueeze(1) * frequencies.unsqueeze(0)
    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    if embedding_dim % 2 == 1:
        # Odd size: append a single zero column.
        embedding = torch.nn.functional.pad(embedding, (0, 1))

    assert embedding.shape == (w.shape[0], embedding_dim)
    return embedding
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: PipelineImageInput = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    guess_mode: bool = False,
    control_guidance_start: Union[float, List[float]] = 0.0,
    control_guidance_end: Union[float, List[float]] = 1.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    pag_scale: float = 3.0,
    pag_adaptive_scale: float = 0.0,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`,
            `List[np.ndarray]`, `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or
            `List[List[PIL.Image.Image]]`):
            The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
            specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
            as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
            width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
            images must be passed as a list such that each element of the list can be correctly batched for input
            to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single
            ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple
            ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. When set to
            `"latent"`, the VAE decode and the safety checker are skipped and the latents are post-processed
            directly.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its `"scale"` entry, when present, is also forwarded to `encode_prompt` as `lora_scale`.
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
            the corresponding scale as a list.
        guess_mode (`bool`, *optional*, defaults to `False`):
            The ControlNet encoder tries to recognize the content of the input image even if you remove all
            prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
        control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
            The percentage of total steps at which the ControlNet starts applying.
        control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
            The percentage of total steps at which the ControlNet stops applying.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during inference, with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class. Replaced by the callback's own
            `tensor_inputs` when a `PipelineCallback` / `MultiPipelineCallbacks` instance is passed.
        pag_scale (`float`, *optional*, defaults to 3.0):
            The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
            guidance will not be used.
        pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
            The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
            used.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """

    # Callback objects carry their own list of tensor names; it overrides the argument.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # Work with the underlying module when the ControlNet has been wrapped by `torch.compile`.
    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

    # align format for control guidance: after this block both values are lists of equal length
    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        # both scalars: one entry per ControlNet
        mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
        control_guidance_start, control_guidance_end = (
            mult * [control_guidance_start],
            mult * [control_guidance_end],
        )

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        image,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        controlnet_conditioning_scale,
        control_guidance_start,
        control_guidance_end,
        callback_on_step_end_tensor_inputs,
    )

    # Stash per-call settings on the instance; the `guidance_scale`, `clip_skip`,
    # `cross_attention_kwargs` and `do_classifier_free_guidance` properties read them back.
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._pag_scale = pag_scale
    self._pag_adaptive_scale = pag_adaptive_scale

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # A single float scale is broadcast to every ControlNet of a MultiControlNetModel.
    if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
        controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

    global_pool_conditions = (
        controlnet.config.global_pool_conditions
        if isinstance(controlnet, ControlNetModel)
        else controlnet.nets[0].config.global_pool_conditions
    )
    # A ControlNet configured with `global_pool_conditions` forces guess mode on.
    guess_mode = guess_mode or global_pool_conditions

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes. With PAG enabled the batch layout is delegated
    # to `_prepare_perturbed_attention_guidance` instead.
    if self.do_perturbed_attention_guidance:
        prompt_embeds = self._prepare_perturbed_attention_guidance(
            prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
        )
    elif self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare image
    if isinstance(controlnet, ControlNetModel):
        image = self.prepare_image(
            image=image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=controlnet.dtype,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            guess_mode=guess_mode,
        )
        # The output resolution follows the preprocessed conditioning image.
        height, width = image.shape[-2:]
    elif isinstance(controlnet, MultiControlNetModel):
        images = []

        # Nested lists as ControlNet condition
        if isinstance(image[0], list):
            # Transpose the nested image list
            image = [list(t) for t in zip(*image)]

        for image_ in image:
            image_ = self.prepare_image(
                image=image_,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )

            images.append(image_)

        image = images
        height, width = image[0].shape[-2:]
    else:
        # NOTE(review): unreachable-branch guard; it is an `assert`, so it disappears under `python -O`.
        assert False

    # 5. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )
    self._num_timesteps = len(timesteps)

    # 6. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6.5 Optionally get Guidance Scale Embedding
    # Only used when the UNet is configured with `time_cond_proj_dim`; in that case
    # `do_classifier_free_guidance` is False and the scale is fed to the UNet as an embedding.
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7.1 Add image embeds for IP-Adapter
    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )
        for i, image_embeds in enumerate(ip_adapter_image_embeds):
            negative_image_embeds = None
            # Split the CFG-stacked embeddings back into their two halves so they can be
            # re-assembled in the layout required by PAG (or plain CFG) below.
            if self.do_classifier_free_guidance:
                negative_image_embeds, image_embeds = image_embeds.chunk(2)

            if self.do_perturbed_attention_guidance:
                image_embeds = self._prepare_perturbed_attention_guidance(
                    image_embeds, negative_image_embeds, self.do_classifier_free_guidance
                )
            elif self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
            image_embeds = image_embeds.to(device)
            ip_adapter_image_embeds[i] = image_embeds

    added_cond_kwargs = (
        {"image_embeds": ip_adapter_image_embeds}
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None
        else None
    )

    controlnet_prompt_embeds = prompt_embeds

    # 7.2 Create tensor stating which controlnets to keep
    # Each entry is 1.0 while step `i` lies inside the [start, end] fraction window and 0.0 otherwise;
    # a single ControlNet stores a float per step, a MultiControlNetModel a list per step.
    controlnet_keep = []
    for i in range(len(timesteps)):
        keeps = [
            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

    # Re-batch every conditioning image to match the prompt-embedding layout (PAG and/or CFG).
    images = image if isinstance(image, list) else [image]
    for i, single_image in enumerate(images):
        # NOTE(review): `prepare_image` only doubles the batch when CFG is on and guess mode is off,
        # but this `chunk(2)[0]` runs whenever CFG is on — confirm the guess-mode batch size is intended.
        if self.do_classifier_free_guidance:
            single_image = single_image.chunk(2)[0]

        if self.do_perturbed_attention_guidance:
            single_image = self._prepare_perturbed_attention_guidance(
                single_image, single_image, self.do_classifier_free_guidance
            )
        elif self.do_classifier_free_guidance:
            single_image = torch.cat([single_image] * 2)
        single_image = single_image.to(device)
        images[i] = single_image

    image = images if isinstance(image, list) else images[0]

    # 8. Denoising loop
    if self.do_perturbed_attention_guidance:
        # Remember the current processors so they can be put back after generation.
        # NOTE(review): the restore at the end of this method is not in a `finally`, so an exception
        # raised in between leaves the PAG processors installed on the UNet.
        original_attn_proc = self.unet.attn_processors
        self._set_pag_attn_processor(
            pag_applied_layers=self.pag_applied_layers,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
        )

    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    is_unet_compiled = is_compiled_module(self.unet)
    is_controlnet_compiled = is_compiled_module(self.controlnet)
    is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Relevant thread:
            # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
            if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
                torch._inductor.cudagraph_mark_step_begin()
            # expand the latents so their batch matches the (CFG / PAG expanded) prompt embeddings
            latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # controlnet(s) inference
            control_model_input = latent_model_input

            # Scale for this step = user conditioning scale x keep-mask (0.0 disables the ControlNet).
            if isinstance(controlnet_keep[i], list):
                cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
            else:
                controlnet_cond_scale = controlnet_conditioning_scale
                if isinstance(controlnet_cond_scale, list):
                    controlnet_cond_scale = controlnet_cond_scale[0]
                cond_scale = controlnet_cond_scale * controlnet_keep[i]

            down_block_res_samples, mid_block_res_sample = self.controlnet(
                control_model_input,
                t,
                encoder_hidden_states=controlnet_prompt_embeds,
                controlnet_cond=image,
                conditioning_scale=cond_scale,
                guess_mode=guess_mode,
                return_dict=False,
            )

            if guess_mode and self.do_classifier_free_guidance:
                # Inferred ControlNet only for the conditional batch.
                # To apply the output of ControlNet to both the unconditional and conditional batches,
                # add 0 to the unconditional batch to keep it unchanged.
                # NOTE(review): `control_model_input` above is the full expanded batch, not only the
                # conditional half — confirm the residual shapes still match the UNet input here.
                down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=self.cross_attention_kwargs,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_perturbed_attention_guidance:
                noise_pred = self._apply_perturbed_attention_guidance(
                    noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
                )
            elif self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                # Hand the requested local tensors to the callback by name.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # The callback may replace these tensors; anything it does not return is kept as is.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # advance the progress bar on the last step, or after warmup once per `scheduler.order` steps
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    # If we do sequential model offloading, let's offload unet and controlnet
    # manually for max memory savings
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.unet.to("cpu")
        self.controlnet.to("cpu")
        torch.cuda.empty_cache()

    if not output_type == "latent":
        # Undo the VAE latent scaling before decoding, then run the safety checker on the decoded image.
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
            0
        ]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Images flagged by the safety checker are not denormalized.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    # Put back the attention processors that were swapped out for PAG.
    if self.do_perturbed_attention_guidance:
        self.unet.set_attn_processor(original_attn_proc)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from diffusers.utils.import_utils import is_invisible_watermark_available from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from .pag_utils import PAGMixin if is_invisible_watermark_available(): from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker from ..controlnet.multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install opencv-python transformers accelerate >>> from diffusers import AutoPipelineForText2Image, ControlNetModel, AutoencoderKL >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> import cv2 >>> from PIL import Image >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" >>> negative_prompt = "low quality, 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure the scheduler through `set_timesteps` and hand back the schedule it produced.

    Supports either a plain step count or a custom schedule given as `timesteps` or `sigmas`. Any extra kwargs
    are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is
            passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
        the second element is the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was passed.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own spacing from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: make sure this scheduler's `set_timesteps` understands the argument first.
    accepted_params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # With a custom schedule the step count is whatever the scheduler ended up with.
    schedule = scheduler.timesteps
    return schedule, len(schedule)
unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): Provides additional conditioning to the `unet` during the denoising process. If you set multiple ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`): Whether the negative prompt embeddings should always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1.0`. add_watermarker (`bool`, *optional*): Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no watermarker is used.
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
    feature_extractor: CLIPImageProcessor = None,
    image_encoder: CLIPVisionModelWithProjection = None,
    pag_applied_layers: Union[str, List[str]] = "mid",  # ["down.block_2", "up.block_1.attentions_0"], "mid"
):
    """
    Register the pipeline components and set up the image processors, watermarker and PAG layers.

    A list or tuple of ControlNets is wrapped in a single `MultiControlNetModel` before registration. When
    `add_watermarker` is `None`, watermarking is enabled exactly when `is_invisible_watermark_available()`
    reports the package as installed.
    """
    super().__init__()

    # Several ControlNets are driven through one wrapper module.
    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )

    # Spatial scale factor derived from the number of VAE block output channels.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    # Conditioning images go through a separate processor with normalization switched off.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )

    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()
    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None

    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
    self.set_pag_applied_layers(pag_applied_layers)
If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
""" device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was 
truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = 
pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed * num_images_per_prompt, -1 ) if do_classifier_free_guidance: negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed * num_images_per_prompt, -1 ) if self.text_encoder is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) if output_hidden_states: image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True ).hidden_states[-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) return image_enc_hidden_states, uncond_image_enc_hidden_states else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): image_embeds = [] if do_classifier_free_guidance: negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): raise ValueError( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): output_hidden_state = not isinstance(image_proj_layer, ImageProjection) single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) single_image_embeds = single_image_embeds.to(device=device) 
ip_adapter_image_embeds.append(single_image_embeds) return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_inputs def check_inputs( self, prompt, prompt_2, image, callback_steps, negative_prompt=None, negative_prompt_2=None, prompt_embeds=None, negative_prompt_embeds=None, pooled_prompt_embeds=None, ip_adapter_image=None, ip_adapter_image_embeds=None, negative_pooled_prompt_embeds=None, controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, callback_on_step_end_tensor_inputs=None, ): if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) elif negative_prompt_2 is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) # `prompt` needs more sophisticated handling when there are multiple # conditionings. if isinstance(self.controlnet, MultiControlNetModel): if isinstance(prompt, list): logger.warning( f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" " prompts. The conditionings will be fixed across the prompts." ) # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.controlnet, torch._dynamo.eval_frame.OptimizedModule ) if ( isinstance(self.controlnet, ControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel) ): self.check_image(image, prompt, prompt_embeds) elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel) ): if not isinstance(image, list): raise TypeError("For multiple controlnets: `image` must be type `list`") # When `image` is a nested list: # (e.g. 
[[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): raise ValueError("A single batch of multiple conditionings are supported at the moment.") elif len(image) != len(self.controlnet.nets): raise ValueError( f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." ) for image_ in image: self.check_image(image_, prompt, prompt_embeds) else: assert False # Check `controlnet_conditioning_scale` if ( isinstance(self.controlnet, ControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel) ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel) ): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): raise ValueError("A single batch of multiple conditionings are supported at the moment.") elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( self.controlnet.nets ): raise ValueError( "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" " the same length as the number of controlnets" ) else: assert False if not isinstance(control_guidance_start, (tuple, list)): control_guidance_start = [control_guidance_start] if not isinstance(control_guidance_end, (tuple, list)): control_guidance_end = [control_guidance_end] if len(control_guidance_start) != len(control_guidance_end): raise ValueError( f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." 
) if isinstance(self.controlnet, MultiControlNetModel): if len(control_guidance_start) != len(self.controlnet.nets): raise ValueError( f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." ) for start, end in zip(control_guidance_start, control_guidance_end): if start >= end: raise ValueError( f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." ) if start < 0.0: raise ValueError(f"control guidance start: {start} can't be smaller than 0.") if end > 1.0: raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." ) if ip_adapter_image_embeds is not None: if not isinstance(ip_adapter_image_embeds, list): raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) image_is_np = isinstance(image, np.ndarray) image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) if ( not image_is_pil and not image_is_tensor and not image_is_np and not 
image_is_pil_list and not image_is_tensor_list and not image_is_np_list ): raise TypeError( f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" ) if image_is_pil: image_batch_size = 1 else: image_batch_size = len(image) if prompt is not None and isinstance(prompt, str): prompt_batch_size = 1 elif prompt is not None and isinstance(prompt, list): prompt_batch_size = len(prompt) elif prompt_embeds is not None: prompt_batch_size = prompt_embeds.shape[0] if image_batch_size != 1 and image_batch_size != prompt_batch_size: raise ValueError( f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" ) # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image def prepare_image( self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance=False, guess_mode=False, ): image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size else: # image batch size is the same as prompt batch size repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) if do_classifier_free_guidance and not guess_mode: image = torch.cat([image] * 2) return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You 
have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features if expected_add_embed_dim != passed_add_embed_dim: raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
) add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae def upcast_vae(self): dtype = self.vae.dtype self.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( self.vae.decoder.mid_block.attentions[0].processor, ( AttnProcessor2_0, XFormersAttnProcessor, ), ) # if xformers or torch_2_0 is used attention block does not need # to be in float32 which can save lots of memory if use_torch_2_0_or_xformers: self.vae.post_quant_conv.to(dtype) self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding( self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: w (`torch.Tensor`): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type of the generated embeddings. Returns: `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. 
""" assert len(w.shape) == 1 w = w * 1000.0 half_dim = embedding_dim // 2 emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) emb = w.to(dtype)[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb @property def guidance_scale(self): return self._guidance_scale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None @property def cross_attention_kwargs(self): return self._cross_attention_kwargs @property def denoising_end(self): return self._denoising_end @property def num_timesteps(self): return self._num_timesteps @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: 
Optional[torch.Tensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], pag_scale: float = 3.0, pag_adaptive_scale: float = 0.0, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders. image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. 
If multiple ControlNets are specified in `init`, images must be passed as a list such that each element of the list can be correctly batched for input to a single ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. 
The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. 
Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, pooled text embeddings are generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set the corresponding scale as a list. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the ControlNet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the ControlNet stops applying. original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a specific image resolution. 
Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a target image resolution. It should be as same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. 
The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. pag_scale (`float`, *optional*, defaults to 3.0): The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention guidance will not be used. pag_adaptive_scale (`float`, *optional*, defaults to 0.0): The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is used. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned containing the output images. """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = ( mult * [control_guidance_start], mult * [control_guidance_end], ) # 1. Check inputs. 
Raise error if not correct self.check_inputs( prompt, prompt_2, image, None, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, negative_pooled_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._pag_scale = pag_scale self._pag_adaptive_scale = pag_adaptive_scale # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) # 3.1 Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt, prompt_2, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt, negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) # 3.2 Encode ip_adapter_image if ip_adapter_image is not None or ip_adapter_image_embeds is not None: ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 4. 
Prepare image if isinstance(controlnet, ControlNetModel): image = self.prepare_image( image=image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=False, ) height, width = image.shape[-2:] elif isinstance(controlnet, MultiControlNetModel): images = [] for image_ in image: image_ = self.prepare_image( image=image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=self.do_classifier_free_guidance, guess_mode=False, ) images.append(image_) image = images height, width = image[0].shape[-2:] else: assert False # 5. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) self._num_timesteps = len(timesteps) # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6.5 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) # 7.2 Prepare added time ids & embeddings if isinstance(image, list): original_size = original_size or image[0].shape[-2:] else: original_size = original_size or image.shape[-2:] target_size = target_size or (height, width) add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids images = image if isinstance(image, list) else [image] for i, single_image in enumerate(images): if self.do_classifier_free_guidance: single_image = single_image.chunk(2)[0] if self.do_perturbed_attention_guidance: single_image = self._prepare_perturbed_attention_guidance( single_image, single_image, self.do_classifier_free_guidance ) elif self.do_classifier_free_guidance: single_image = torch.cat([single_image] * 2) single_image = single_image.to(device) images[i] = single_image image = images if isinstance(image, list) else images[0] if ip_adapter_image_embeds is not None: for i, 
image_embeds in enumerate(ip_adapter_image_embeds): negative_image_embeds = None if self.do_classifier_free_guidance: negative_image_embeds, image_embeds = image_embeds.chunk(2) if self.do_perturbed_attention_guidance: image_embeds = self._prepare_perturbed_attention_guidance( image_embeds, negative_image_embeds, self.do_classifier_free_guidance ) elif self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) image_embeds = image_embeds.to(device) ip_adapter_image_embeds[i] = image_embeds if self.do_perturbed_attention_guidance: prompt_embeds = self._prepare_perturbed_attention_guidance( prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance ) add_text_embeds = self._prepare_perturbed_attention_guidance( add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance ) add_time_ids = self._prepare_perturbed_attention_guidance( add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance ) elif self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} controlnet_prompt_embeds = prompt_embeds controlnet_added_cond_kwargs = added_cond_kwargs # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order # 8.1 Apply denoising_end if ( self.denoising_end is not None and isinstance(self.denoising_end, float) and self.denoising_end > 0 and self.denoising_end < 1 ): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] if self.do_perturbed_attention_guidance: original_attn_proc = self.unet.attn_processors self._set_pag_attn_processor( pag_applied_layers=self.pag_applied_layers, do_classifier_free_guidance=self.do_classifier_free_guidance, ) is_unet_compiled = is_compiled_module(self.unet) is_controlnet_compiled = is_compiled_module(self.controlnet) is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference control_model_input = latent_model_input if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, 
encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=image, conditioning_scale=cond_scale, guess_mode=False, added_cond_kwargs=controlnet_added_cond_kwargs, return_dict=False, ) if ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_perturbed_attention_guidance: noise_pred = self._apply_perturbed_attention_guidance( noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) # call the callback, if provided if i == len(timesteps) - 
1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None if has_latents_mean and has_latents_std: latents_mean = ( torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents_std = ( torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean else: latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents if not output_type == "latent": # apply watermark if available if self.watermark is not None: image = self.watermark.apply_watermark(image) image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if self.do_perturbed_attention_guidance: self.unet.set_attn_processor(original_attn_proc) if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py ================================================ # Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, HunyuanDiT2DModel from ...models.attention_processor import PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0 from ...models.embeddings import get_2d_rotary_pos_embed from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import DDPMScheduler from ...utils import ( is_torch_xla_available, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pag_utils import PAGMixin if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```python >>> import torch >>> from diffusers import AutoPipelineForText2Image >>> pipe = AutoPipelineForText2Image.from_pretrained( ... "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", ... torch_dtype=torch.float16, ... enable_pag=True, ... pag_applied_layers=[14], ... 
# Width / height ratio of each resolution bucket; index-aligned with STANDARD_SHAPE.
STANDARD_RATIO = np.array(
    [
        1.0,  # 1:1
        4.0 / 3.0,  # 4:3
        3.0 / 4.0,  # 3:4
        16.0 / 9.0,  # 16:9
        9.0 / 16.0,  # 9:16
    ]
)
# (width, height) candidates grouped per bucket, in the same order as STANDARD_RATIO.
STANDARD_SHAPE = [
    [(1024, 1024), (1280, 1280)],  # 1:1
    [(1024, 768), (1152, 864), (1280, 960)],  # 4:3
    [(768, 1024), (864, 1152), (960, 1280)],  # 3:4
    [(1280, 768)],  # 16:9
    [(768, 1280)],  # 9:16
]
# Pixel count of every candidate, one array per bucket.
STANDARD_AREA = [np.array([width * height for width, height in bucket]) for bucket in STANDARD_SHAPE]
# Flat list of every candidate shape, in bucket order.
SUPPORTED_SHAPE = [shape for bucket in STANDARD_SHAPE for shape in bucket]


def map_to_standard_shapes(target_width, target_height):
    """
    Snap a requested resolution to the nearest predefined one.

    The bucket is the entry of `STANDARD_RATIO` closest to `target_width / target_height`; inside that bucket the
    shape whose pixel area is closest to `target_width * target_height` is selected.

    Args:
        target_width: Requested width.
        target_height: Requested height.

    Returns:
        `tuple`: The `(width, height)` entry of `STANDARD_SHAPE` that was selected.
    """
    ratio_gap = np.abs(STANDARD_RATIO - target_width / target_height)
    bucket_idx = np.argmin(ratio_gap)
    area_gap = np.abs(STANDARD_AREA[bucket_idx] - target_width * target_height)
    best_width, best_height = STANDARD_SHAPE[bucket_idx][np.argmin(area_gap)]
    return best_width, best_height


def get_resize_crop_region_for_grid(src, tgt_size):
    """
    Fit a `(height, width)` source inside a square of side `tgt_size` and center it.

    The longer source side is scaled to `tgt_size`, the other side is scaled by the same factor (rounded), and the
    resized rectangle is centered in the square.

    Args:
        src: `(height, width)` of the source.
        tgt_size: Side length of the square target.

    Returns:
        `tuple`: `((top, left), (bottom, right))` corners of the centered, resized rectangle.
    """
    src_h, src_w = src
    grid_h = grid_w = tgt_size

    if src_h / src_w > 1:
        # taller than wide: height fills the square
        new_h, new_w = grid_h, int(round(grid_h / src_h * src_w))
    else:
        # wider than tall (or square): width fills the square
        new_h, new_w = int(round(grid_w / src_w * src_h)), grid_w

    top = int(round((grid_h - new_h) / 2.0))
    left = int(round((grid_w - new_w) / 2.0))
    return (top, left), (top + new_h, left + new_w)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend `noise_cfg` with a copy of itself rescaled to the standard deviation of `noise_pred_text`.

    Based on Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg: Guided noise prediction to rescale.
        noise_pred_text: Text-conditioned noise prediction whose standard deviation is the target.
        guidance_rescale: Weight of the rescaled prediction in the blend; `0.0` returns `noise_cfg` weighted by 1.

    Returns:
        `guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg`.
    """
    # standard deviation over every dimension except the first, kept broadcastable
    text_reduce_dims = list(range(1, noise_pred_text.ndim))
    cfg_reduce_dims = list(range(1, noise_cfg.ndim))
    text_std = noise_pred_text.std(dim=text_reduce_dims, keepdim=True)
    cfg_std = noise_cfg.std(dim=cfg_reduce_dims, keepdim=True)

    # bring the guided prediction back to the text prediction's scale (fixes overexposure)
    rescaled = noise_cfg * (text_std / cfg_std)

    # mixing with the unscaled prediction avoids "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
# Class-level pipeline configuration.
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = [
    "safety_checker",
    "feature_extractor",
    "text_encoder_2",
    "tokenizer_2",
    "text_encoder",
    "tokenizer",
]
_exclude_from_cpu_offload = ["safety_checker"]
# Names that `check_inputs` accepts in `callback_on_step_end_tensor_inputs`.
_callback_tensor_inputs = [
    "latents",
    "prompt_embeds",
    "negative_prompt_embeds",
    "prompt_embeds_2",
    "negative_prompt_embeds_2",
]

def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: BertModel,
    tokenizer: BertTokenizer,
    transformer: HunyuanDiT2DModel,
    scheduler: DDPMScheduler,
    safety_checker: Optional[StableDiffusionSafetyChecker] = None,
    feature_extractor: Optional[CLIPImageProcessor] = None,
    requires_safety_checker: bool = True,
    text_encoder_2: Optional[T5EncoderModel] = None,
    tokenizer_2: Optional[MT5Tokenizer] = None,
    pag_applied_layers: Union[str, List[str]] = "blocks.1",  # "blocks.16.attn1", "blocks.16", "16", 16
):
    """
    Register the pipeline components and derive the image/latent size defaults.

    Args:
        vae: Autoencoder; its `config.block_out_channels` determines `vae_scale_factor`.
        text_encoder / tokenizer: First text encoder and its tokenizer (index `0` in `encode_prompt`).
        transformer: Denoising transformer; its `config.sample_size` becomes `default_sample_size`.
        scheduler: Scheduler used in the denoising loop.
        safety_checker / feature_extractor: Optional safety checker and the feature extractor it needs.
        requires_safety_checker: When `True` and no safety checker is given, a warning is logged.
        text_encoder_2 / tokenizer_2: Optional second text encoder and tokenizer (index `1`).
        pag_applied_layers: Layer name(s) forwarded to `set_pag_applied_layers`.

    Raises:
        ValueError: If a safety checker is supplied without a feature extractor.
    """
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        transformer=transformer,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        text_encoder_2=text_encoder_2,
    )

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message interpolates the class name, so it must be an f-string.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Fall back to 8 when no VAE is registered.
    self.vae_scale_factor = (
        2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
    )
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
    # Fall back to 128 when no transformer is registered.
    self.default_sample_size = (
        self.transformer.config.sample_size
        if hasattr(self, "transformer") and self.transformer is not None
        else 128
    )

    self.set_pag_applied_layers(
        pag_applied_layers, pag_attn_processors=(PAGCFGHunyuanAttnProcessor2_0(), PAGHunyuanAttnProcessor2_0())
    )

# Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt
def encode_prompt(
    self,
    prompt: str,
    device: torch.device = None,
    dtype: torch.dtype = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[str] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    max_sequence_length: Optional[int] = None,
    text_encoder_index: int = 0,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device; defaults to `self._execution_device`
        dtype (`torch.dtype`):
            torch dtype; defaults to the dtype of `text_encoder_2`, else of `transformer`
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when `do_classifier_free_guidance` is `True`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
        max_sequence_length (`int`, *optional*):
            maximum sequence length to use for the prompt. Defaults to 77 for index `0` and 256 for index `1`.
        text_encoder_index (`int`, *optional*):
            Index of the text encoder to use. `0` selects `text_encoder`/`tokenizer`, `1` selects
            `text_encoder_2`/`tokenizer_2`.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, prompt_attention_mask,
        negative_prompt_attention_mask)`.
    """
    if dtype is None:
        if self.text_encoder_2 is not None:
            dtype = self.text_encoder_2.dtype
        elif self.transformer is not None:
            dtype = self.transformer.dtype
        else:
            dtype = None

    if device is None:
        device = self._execution_device

    tokenizers = [self.tokenizer, self.tokenizer_2]
    text_encoders = [self.text_encoder, self.text_encoder_2]

    tokenizer = tokenizers[text_encoder_index]
    text_encoder = text_encoders[text_encoder_index]

    if max_sequence_length is None:
        if text_encoder_index == 0:
            max_length = 77
        if text_encoder_index == 1:
            max_length = 256
    else:
        max_length = max_sequence_length

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation only to report what was cut off.
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {tokenizer.model_max_length} tokens: {removed_text}"
            )

        prompt_attention_mask = text_inputs.attention_mask.to(device)
        prompt_embeds = text_encoder(
            text_input_ids.to(device),
            attention_mask=prompt_attention_mask,
        )
        prompt_embeds = prompt_embeds[0]
        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
        negative_prompt_embeds = text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=negative_prompt_attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]
        negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker on `image`.

    Returns `(image, None)` untouched when no safety checker is registered; otherwise returns whatever
    `self.safety_checker` yields for the image together with its per-image flags.
    """
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images, whichever form `image` arrives in.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    checker_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    image, nsfw_flags = self.safety_checker(images=image, clip_input=checker_features.pixel_values.to(dtype))
    return image, nsfw_flags

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the keyword arguments for `self.scheduler.step`.

    Schedulers differ in their `step` signature, so `eta` and `generator` are only forwarded when the
    signature declares a parameter of that name. `eta` corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters

    step_kwargs = {}
    if "eta" in step_parameters:
        step_kwargs["eta"] = eta
    if "generator" in step_parameters:
        step_kwargs["generator"] = generator
    return step_kwargs

# Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    prompt_attention_mask=None,
    negative_prompt_attention_mask=None,
    prompt_embeds_2=None,
    negative_prompt_embeds_2=None,
    prompt_attention_mask_2=None,
    negative_prompt_attention_mask_2=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the arguments of `__call__`; returns `None` when they are consistent.

    Raises:
        ValueError: On a size not divisible by 8, an unknown callback tensor name, a conflicting or
            missing prompt / embedding combination, a missing attention mask, or mismatched
            positive / negative embedding shapes.
    """
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None:
        unknown_inputs = [
            k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs
        ]
        if unknown_inputs:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown_inputs}"
            )

    if prompt is None:
        # Without a text prompt, both sets of pre-computed embeddings are mandatory.
        if prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        if prompt_embeds_2 is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
            )
    else:
        if prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if prompt_embeds is not None and prompt_attention_mask is None:
        raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

    if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
        raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
        raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

    if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
        raise ValueError(
            "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
        )

    both_given = prompt_embeds is not None and negative_prompt_embeds is not None
    if both_given and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    both_given_2 = prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None
    if both_given_2 and prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
        raise ValueError(
            "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
            f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
            f" {negative_prompt_embeds_2.shape}."
        )

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the initial latents, scaled by the scheduler's `init_noise_sigma`.

    Fresh noise is drawn with `randn_tensor` when `latents` is `None`; otherwise the supplied tensor is
    moved to `device`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(width) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        initial = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        initial = latents.to(device)

    # The scheduler dictates the standard deviation of the starting noise.
    return initial * self.scheduler.init_noise_sigma

@property
def guidance_scale(self):
    """Value stored in `self._guidance_scale`."""
    return self._guidance_scale

@property
def guidance_rescale(self):
    """Value stored in `self._guidance_rescale`."""
    return self._guidance_rescale

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """`True` when the stored guidance scale is greater than 1."""
    return self._guidance_scale > 1

@property
def num_timesteps(self):
    """Number of scheduler timesteps recorded by the last `__call__`."""
    return self._num_timesteps

@property
def interrupt(self):
    """Flag checked at every denoising step; when set, the remaining steps are skipped."""
    return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_2: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds_2: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    prompt_attention_mask_2: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    guidance_rescale: float = 0.0,
    original_size: Optional[Tuple[int, int]] = (1024, 1024),
    target_size: Optional[Tuple[int, int]] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    use_resolution_binning: bool = True,
    pag_scale: float = 3.0,
    pag_adaptive_scale: float = 0.0,
):
    r"""
    The call function to the pipeline for generation with HunyuanDiT.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`):
            The height in pixels of the generated image. Defaults to
            `default_sample_size * vae_scale_factor` and is rounded down to a multiple of 16.
        width (`int`):
            The width in pixels of the generated image. Defaults to
            `default_sample_size * vae_scale_factor` and is rounded down to a multiple of 16.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated latents to start from. If not provided, they are sampled using `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        prompt_embeds_2 (`torch.Tensor`, *optional*):
            Pre-generated text embeddings for the second text encoder. If not provided, they are generated from
            the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings for the second text encoder. If not provided, they are
            generated from the `negative_prompt` input argument.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
        prompt_attention_mask_2 (`torch.Tensor`, *optional*):
            Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
        negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
            Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. With
            `"latent"` the VAE decode and safety check are skipped.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A callback function or a list of callback functions to be called at the end of each denoising step.
        callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
            A list of tensor inputs that should be passed to the callback function. Every name must be listed
            in `self._callback_tensor_inputs`. Overridden by `callback_on_step_end.tensor_inputs` when a
            `PipelineCallback` / `MultiPipelineCallbacks` is given.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
            Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
        original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
            The original size of the image. Used to calculate the time ids.
        target_size (`Tuple[int, int]`, *optional*):
            The target size of the image. Used to calculate the time ids. Defaults to `(height, width)`.
        crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
            The top left coordinates of the crop. Used to calculate the time ids.
        use_resolution_binning (`bool`, *optional*, defaults to `True`):
            Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the
            closest standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864,
            1280x960, 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to
            `True`.
        pag_scale (`float`, *optional*, defaults to 3.0):
            The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
            guidance will not be used.
        pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
            The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
            used.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 0. Default height and width
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor
    height = int((height // 16) * 16)
    width = int((width // 16) * 16)

    if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE:
        width, height = map_to_standard_shapes(width, height)
        height = int(height)
        width = int(width)
        logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}")

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
        prompt_embeds_2,
        negative_prompt_embeds_2,
        prompt_attention_mask_2,
        negative_prompt_attention_mask_2,
        callback_on_step_end_tensor_inputs,
    )
    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._interrupt = False
    self._pag_scale = pag_scale
    self._pag_adaptive_scale = pag_adaptive_scale

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt (once per text encoder)
    (
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt=prompt,
        device=device,
        dtype=self.transformer.dtype,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        max_sequence_length=77,
        text_encoder_index=0,
    )
    (
        prompt_embeds_2,
        negative_prompt_embeds_2,
        prompt_attention_mask_2,
        negative_prompt_attention_mask_2,
    ) = self.encode_prompt(
        prompt=prompt,
        device=device,
        dtype=self.transformer.dtype,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds_2,
        negative_prompt_embeds=negative_prompt_embeds_2,
        prompt_attention_mask=prompt_attention_mask_2,
        negative_prompt_attention_mask=negative_prompt_attention_mask_2,
        max_sequence_length=256,
        text_encoder_index=1,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Create image_rotary_emb, style embedding & time ids
    grid_height = height // 8 // self.transformer.config.patch_size
    grid_width = width // 8 // self.transformer.config.patch_size
    base_size = 512 // 8 // self.transformer.config.patch_size
    grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
    image_rotary_emb = get_2d_rotary_pos_embed(
        self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
    )

    style = torch.tensor([0], device=device)

    target_size = target_size or (height, width)
    add_time_ids = list(original_size + target_size + crops_coords_top_left)
    add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_perturbed_attention_guidance:
        prompt_embeds = self._prepare_perturbed_attention_guidance(
            prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
        )
        prompt_attention_mask = self._prepare_perturbed_attention_guidance(
            prompt_attention_mask, negative_prompt_attention_mask, self.do_classifier_free_guidance
        )
        prompt_embeds_2 = self._prepare_perturbed_attention_guidance(
            prompt_embeds_2, negative_prompt_embeds_2, self.do_classifier_free_guidance
        )
        prompt_attention_mask_2 = self._prepare_perturbed_attention_guidance(
            prompt_attention_mask_2, negative_prompt_attention_mask_2, self.do_classifier_free_guidance
        )
        add_time_ids = torch.cat([add_time_ids] * 3, dim=0)
        style = torch.cat([style] * 3, dim=0)
    elif self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
        prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
        prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
        add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
        style = torch.cat([style] * 2, dim=0)

    prompt_embeds = prompt_embeds.to(device=device)
    prompt_attention_mask = prompt_attention_mask.to(device=device)
    prompt_embeds_2 = prompt_embeds_2.to(device=device)
    prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
    add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
        batch_size * num_images_per_prompt, 1
    )
    style = style.to(device=device).repeat(batch_size * num_images_per_prompt)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)

    if self.do_perturbed_attention_guidance:
        # Remember the current processors so they can be restored after generation.
        original_attn_proc = self.transformer.attn_processors
        self._set_pag_attn_processor(
            pag_applied_layers=self.pag_applied_layers,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
        )

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
            t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
                dtype=latent_model_input.dtype
            )

            # predict the noise residual
            noise_pred = self.transformer(
                latent_model_input,
                t_expand,
                encoder_hidden_states=prompt_embeds,
                text_embedding_mask=prompt_attention_mask,
                encoder_hidden_states_t5=prompt_embeds_2,
                text_embedding_mask_t5=prompt_attention_mask_2,
                image_meta_size=add_time_ids,
                style=style,
                image_rotary_emb=image_rotary_emb,
                return_dict=False,
            )[0]

            # Only the first half of the channel dimension is used as the noise prediction.
            noise_pred, _ = noise_pred.chunk(2, dim=1)

            # perform guidance
            noise_pred_text = None
            if self.do_perturbed_attention_guidance:
                if self.do_classifier_free_guidance and guidance_rescale > 0.0:
                    # `rescale_noise_cfg` below needs the text-conditioned prediction, which the PAG
                    # helper does not return. Without this, `noise_pred_text` would be unbound here.
                    # NOTE(review): assumes the PAG+CFG batch is ordered
                    # [unconditional, text, perturbed] - confirm against PAGMixin.
                    noise_pred_text = noise_pred.chunk(3)[1]
                noise_pred = self._apply_perturbed_attention_guidance(
                    noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
                )
            elif self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if self.do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                # Tensors are looked up by local variable name, so the names of the locals below
                # must stay in sync with `_callback_tensor_inputs`.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
                negative_prompt_embeds_2 = callback_outputs.pop(
                    "negative_prompt_embeds_2", negative_prompt_embeds_2
                )

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # 9. Offload all models
    self.maybe_free_model_hooks()

    if self.do_perturbed_attention_guidance:
        self.transformer.set_attn_processor(original_attn_proc)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..kolors.pipeline_output import KolorsPipelineOutput from ..kolors.text_encoder import ChatGLMModel from ..kolors.tokenizer import ChatGLMTokenizer from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pag_utils import PAGMixin if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import AutoPipelineForText2Image >>> pipe = AutoPipelineForText2Image.from_pretrained( ... "Kwai-Kolors/Kolors-diffusers", ... variant="fp16", ... torch_dtype=torch.float16, ... enable_pag=True, ... pag_applied_layers=["down.block_2.attentions_1", "up.block_0.attentions_1"], ... ) >>> pipe = pipe.to("cuda") >>> prompt = ( ... "A photo of a ladybug, macro, zoom, high quality, film, holding a wooden sign with the text 'KOLORS'" ... 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure the scheduler's schedule via `set_timesteps` and return the resulting timesteps.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or (when neither is
    given) a plain `num_inference_steps` count. Extra keyword arguments are forwarded to
    `scheduler.set_timesteps` unchanged.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read the timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is passed.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration, and the number of inference steps
        (the length of the schedule when a custom schedule was supplied, otherwise `num_inference_steps` as
        passed in).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    custom_timesteps = timesteps is not None
    custom_sigmas = sigmas is not None

    if custom_timesteps and custom_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler derive its own schedule from the step count.
    if not (custom_timesteps or custom_sigmas):
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: the scheduler must expose the matching keyword on `set_timesteps`.
    supported = inspect.signature(scheduler.set_timesteps).parameters
    if custom_timesteps:
        if "timesteps" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"False"`): Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of `Kwai-Kolors/Kolors-diffusers`. pag_applied_layers (`str` or `List[str]``, *optional*, defaults to `"mid"`): Set the transformer attention layers where to apply the perturbed attention guidance. Can be a string or a list of strings with "down", "mid", "up", a whole transformer block or specific transformer block attention layers, e.g.: ["mid"] ["down", "mid"] ["down", "mid", "up.block_1"] ["down", "mid", "up.block_1.attentions_0", "up.block_1.attentions_1"] """ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = [ "image_encoder", "feature_extractor", ] _callback_tensor_inputs = [ "latents", "prompt_embeds", "negative_prompt_embeds", "add_text_embeds", "add_time_ids", "negative_pooled_prompt_embeds", "negative_add_time_ids", ] def __init__( self, vae: AutoencoderKL, text_encoder: ChatGLMModel, tokenizer: ChatGLMTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, image_encoder: CLIPVisionModelWithProjection = None, feature_extractor: CLIPImageProcessor = None, force_zeros_for_empty_prompt: bool = False, pag_applied_layers: Union[str, List[str]] = "mid", ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, image_encoder=image_encoder, feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and 
# Adapted from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt
def encode_prompt(
    self,
    prompt,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 256,
):
    r"""
    Encodes the prompt (and, for classifier-free guidance, the negative prompt) into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded. May be `None` when `prompt_embeds` is supplied.
        device: (`torch.device`, *optional*):
            torch device; falls back to `self._execution_device` when not given.
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt; embeddings computed here are repeated this
            many times.
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not. Negative embeddings are only produced when `True`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. When omitted, either empty strings are
            encoded or, if `self.config.force_zeros_for_empty_prompt` is set, all-zero embeddings are used.
            A single string is applied to every item of the batch.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, they are generated from `prompt`.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. If `prompt_embeds` is not provided, they are generated from
            `prompt`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they are generated from `negative_prompt`.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. If `negative_prompt_embeds` is not provided and the
            negative prompt is encoded, this value is regenerated from `negative_prompt`.
        max_sequence_length (`int` defaults to 256):
            Maximum sequence length to use with the `prompt`.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)`.
        The negative entries are returned as passed in (possibly `None`) when `do_classifier_free_guidance` is
        `False`.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        ValueError: If a list `negative_prompt` does not match the batch size.
    """
    device = device or self._execution_device

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        output = self.text_encoder(
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"],
            position_ids=text_inputs["position_ids"],
            output_hidden_states=True,
        )

        # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size]
        # clone to have a contiguous tensor
        prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
        # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size]
        pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()

        # duplicate text embeddings for each generation per prompt
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt

    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        # The pooled negative embedding is repeated further down, so it must exist on this path too;
        # without it the `.repeat` below would be called on `None`.
        if negative_pooled_prompt_embeds is None:
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            # Broadcast a single negative prompt over the batch so the reshape below is valid when
            # `prompt_embeds` with a batch larger than one was supplied instead of `prompt`.
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        output = self.text_encoder(
            input_ids=uncond_input["input_ids"],
            attention_mask=uncond_input["attention_mask"],
            position_ids=uncond_input["position_ids"],
            output_hidden_states=True,
        )

        # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size]
        # clone to have a contiguous tensor
        negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
        # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size]
        negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()

        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]
        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    bs_embed = pooled_prompt_embeds.shape[0]
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )

    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): image_embeds = [] if do_classifier_free_guidance: negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): raise ValueError( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): output_hidden_state = not isinstance(image_proj_layer, ImageProjection) single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) single_image_embeds = single_image_embeds.to(device=device) 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only included when the
    scheduler's `step` declares a parameter of that name.

    Args:
        generator: Value forwarded as `generator` when supported by the scheduler.
        eta: Value forwarded as `eta` when supported by the scheduler. eta corresponds to η in the DDIM paper
            (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        `dict`: A mapping holding `"eta"` and/or `"generator"` (in that order), or an empty dict.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
) if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the initial latents for denoising, scaled by the scheduler's `init_noise_sigma`.

    Args:
        batch_size: Number of latent samples (first dimension of the sampled shape).
        num_channels_latents: Number of latent channels (second dimension of the sampled shape).
        height: Image height in pixels; divided by `self.vae_scale_factor` for the latent height.
        width: Image width in pixels; divided by `self.vae_scale_factor` for the latent width.
        dtype: dtype passed to `randn_tensor` when latents are sampled.
        device: Device the latents are sampled on, or moved to when `latents` is given.
        generator: A generator or list of generators passed to `randn_tensor`.
        latents: Optional pre-generated latents; when given they are only moved to `device` and scaled.

    Returns:
        The latents multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height, latent_width = (int(side) // self.vae_scale_factor for side in (height, width))
    latent_shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if isinstance(generator, list):
        generator_count = len(generator)
        if generator_count != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {generator_count}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

    if latents is not None:
        initial_latents = latents.to(device)
    else:
        initial_latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial_latents * self.scheduler.init_noise_sigma
# Adapted from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Build sinusoidal embeddings of the guidance scale values in `w`.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scale values to embed.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate. An odd value is zero-padded by one column.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, laid out as all sine terms
        followed by all cosine terms, on the same device as `w`.
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    # With a single frequency (`half_dim == 1`) the usual `half_dim - 1` divisor is zero, which turned the
    # whole embedding into NaN; clamp it so that case yields the single frequency 1.
    emb = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
    # Build the frequency table on `w`'s device so inputs living on an accelerator do not hit a
    # cross-device multiplication below.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, original_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Optional[Tuple[int, int]] = None, negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], pag_scale: float = 3.0, pag_adaptive_scale: float = 0.0, max_sequence_length: int = 256, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints that are not specifically fine-tuned on low resolutions. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints that are not specifically fine-tuned on low resolutions. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. 
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.kolors.KolorsPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a specific image resolution. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a target image resolution. It should be as same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. 
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. pag_scale (`float`, *optional*, defaults to 3.0): The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention guidance will not be used. pag_adaptive_scale (`float`, *optional*, defaults to 0.0): The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is used. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. Examples: Returns: [`~pipelines.kolors.KolorsPipelineOutput`] or `tuple`: [`~pipelines.kolors.KolorsPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 1. Check inputs. 
Raise error if not correct self.check_inputs( prompt, num_inference_steps, height, width, negative_prompt, prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._interrupt = False self._pag_scale = pag_scale self._pag_adaptive_scale = pag_adaptive_scale # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. Encode input prompt ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, ) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids if self.do_perturbed_attention_guidance: prompt_embeds = self._prepare_perturbed_attention_guidance( prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance ) add_text_embeds = self._prepare_perturbed_attention_guidance( add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance ) add_time_ids = self._prepare_perturbed_attention_guidance( add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance ) elif self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) for i, image_embeds in enumerate(ip_adapter_image_embeds): negative_image_embeds = None if self.do_classifier_free_guidance: negative_image_embeds, image_embeds = image_embeds.chunk(2) if 
self.do_perturbed_attention_guidance: image_embeds = self._prepare_perturbed_attention_guidance( image_embeds, negative_image_embeds, self.do_classifier_free_guidance ) elif self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) image_embeds = image_embeds.to(device) ip_adapter_image_embeds[i] = image_embeds # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) # 8.1 Apply denoising_end if ( self.denoising_end is not None and isinstance(self.denoising_end, float) and self.denoising_end > 0 and self.denoising_end < 1 ): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # 9. Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) if self.do_perturbed_attention_guidance: original_attn_proc = self.unet.attn_processors self._set_pag_attn_processor( pag_applied_layers=self.pag_applied_layers, do_classifier_free_guidance=self.do_classifier_free_guidance, ) self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": 
add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_perturbed_attention_guidance: noise_pred = self._apply_perturbed_attention_guidance( noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: xm.mark_step() if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) elif latents.dtype != self.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) # unscale/denormalize the latents latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents if not output_type == "latent": image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if self.do_perturbed_attention_guidance: self.unet.set_attn_processor(original_attn_proc) if not return_dict: return (image,) return KolorsPipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py ================================================ # Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import html import inspect import re import urllib.parse as ul from typing import Callable, List, Optional, Tuple, Union import torch from transformers import T5EncoderModel, T5Tokenizer from ...image_processor import PixArtImageProcessor from ...models import AutoencoderKL, PixArtTransformer2DModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( BACKENDS_MAPPING, deprecate, is_bs4_available, is_ftfy_available, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pixart_alpha.pipeline_pixart_alpha import ( ASPECT_RATIO_256_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_1024_BIN, ) from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN from .pag_utils import PAGMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_bs4_available(): from bs4 import BeautifulSoup if is_ftfy_available(): import ftfy EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import AutoPipelineForText2Image >>> pipe = AutoPipelineForText2Image.from_pretrained( ... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", ... torch_dtype=torch.float16, ... pag_applied_layers=["blocks.14"], ... enable_pag=True, ... ) >>> pipe = pipe.to("cuda") >>> prompt = "A small cactus with a happy face in the Sahara desert" >>> image = pipe(prompt, pag_scale=4.0, guidance_scale=1.0).images[0] ``` """ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs, ): """ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
Args: scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. sigmas (`List[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." 
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    do_classifier_free_guidance: bool = True,
    negative_prompt: str = "",
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    clean_caption: bool = False,
    max_sequence_length: int = 300,
    **kwargs,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
            instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
            PixArt-Alpha, this should be "".
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            whether to use classifier free guidance or not
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            number of images that should be generated per prompt
        device: (`torch.device`, *optional*):
            torch device to place the resulting embeddings on
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. For PixArt-Alpha, it should be the embeddings of the ""
            string.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask belonging to `prompt_embeds`. It is computed by the tokenizer when `prompt_embeds` is
            not passed; it is required when `prompt_embeds` is passed, because it is reshaped below.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask belonging to `negative_prompt_embeds`. It is computed by the tokenizer when
            `negative_prompt_embeds` is not passed and classifier free guidance is enabled.
        clean_caption (`bool`, defaults to `False`):
            If `True`, the function will preprocess and clean the provided caption before encoding.
        max_sequence_length (`int`, defaults to 300):
            Maximum sequence length to use for the prompt.

    Returns:
        `tuple`: `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask)`.
        Every returned tensor has `batch_size * num_images_per_prompt` rows, and row `i` of a mask always belongs
        to row `i` of the matching embeddings. The two negative entries are `None` when
        `do_classifier_free_guidance` is `False`.
    """
    # NOTE: this method deliberately differs from the pipeline it was copied from in how the attention masks
    # are duplicated (see the two "keep mask rows aligned" comments below).

    if "mask_feature" in kwargs:
        deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
        deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)

    if device is None:
        device = self._execution_device

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # See Section 3.1. of the paper.
    max_length = max_sequence_length

    if prompt_embeds is None:
        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize a second time without truncation so that dropped text can be reported to the user.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because T5 can only handle sequences up to"
                f" {max_length} tokens: {removed_text}"
            )

        prompt_attention_mask = text_inputs.attention_mask
        prompt_attention_mask = prompt_attention_mask.to(device)

        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
        prompt_embeds = prompt_embeds[0]

    # Pick the dtype of whichever model is loaded; `None` leaves the dtype of the embeddings untouched.
    if self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    elif self.transformer is not None:
        dtype = self.transformer.dtype
    else:
        dtype = None

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    # keep mask rows aligned with the embeddings: the repeat + view above yields [p0, p0, p1, p1, ...], so the
    # mask must be duplicated the same way. Tiling along the batch axis (`repeat(n, 1)`) would instead yield
    # [m0, m1, m0, m1, ...] and pair masks with the wrong prompts.
    prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
    prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
    prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
        uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        negative_prompt_attention_mask = uncond_input.attention_mask
        negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # keep mask rows aligned with the embeddings (same reasoning as for the positive mask above)
        negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
        negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
        negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
    else:
        negative_prompt_embeds = None
        negative_prompt_attention_mask = None

    return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs def check_inputs( self, prompt, height, width, negative_prompt, callback_steps, prompt_embeds=None, negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and prompt_attention_mask is None: raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: raise ValueError( "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" f" {negative_prompt_attention_mask.shape}." 
def _clean_caption(self, caption):
    """
    Normalize a raw caption before it is tokenized.

    The caption is converted to `str`, URL-unquoted, lower-cased and then passed through a fixed sequence of
    substitutions that drop URLs, HTML markup, @-mentions, several CJK ranges, HTML entity leftovers, IP
    addresses, numeric ids, file names, repeated punctuation and a handful of boilerplate phrases. Dash and
    quote variants are mapped to a single character each.

    The order of the substitutions matters: later patterns run on the output of earlier ones.

    Requires `bs4` (`BeautifulSoup`) and `ftfy` to be importable.

    Args:
        caption: The caption to clean; any object accepted by `str()`.

    Returns:
        `str`: The cleaned caption with leading and trailing whitespace removed.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # Replace the literal "<person>" placeholder with the plain word. (An empty pattern here would match
    # at every position and insert "person" between all characters.)
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html: keep only the text content of any markup
    caption = BeautifulSoup(caption, features="html.parser").text

    # @nickname mentions
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quote characters to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # leftover "&quot" / "&quot;" entities
    caption = re.sub(r"&quot;?", "", caption)
    # leftover "&amp" entities (only the entity, not every ampersand)
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # literal backslash-n sequences
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # collapse runs of quotes and runs of dots
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    # Only split on "-"/"_" when there are more than three of them, so ordinary hyphenated words survive.
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    # Unescape twice to also resolve double-escaped entities.
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    # dimensions such as "1024x768" (latin x, cyrillic х, or ×)
    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE(review): a bare `caption.strip()` whose result was discarded used to sit here. It had no effect, so
    # it is dropped rather than turned into an assignment: stripping at this point would change which captions
    # match the `^`/`$`-anchored patterns below and therefore change the cleaned output.

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, negative_prompt: str = "", num_inference_steps: int = 20, timesteps: List[int] = None, sigmas: List[float] = None, guidance_scale: float = 4.5, num_images_per_prompt: Optional[int] = 1, height: Optional[int] = None, width: Optional[int] = None, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, clean_caption: bool = True, use_resolution_binning: bool = True, max_sequence_length: int = 300, pag_scale: float = 3.0, pag_adaptive_scale: float = 0.0, ) -> Union[ImagePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 4.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. height (`int`, *optional*, defaults to self.unet.config.sample_size): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size): The width in pixels of the generated image. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for negative text embeddings. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. clean_caption (`bool`, *optional*, defaults to `True`): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. 
use_resolution_binning (`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to the requested resolution. Useful for generating non-square images. max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`. pag_scale (`float`, *optional*, defaults to 3.0): The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention guidance will not be used. pag_adaptive_scale (`float`, *optional*, defaults to 0.0): The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is used. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images """ # 1. Check inputs. 
Raise error if not correct height = height or self.transformer.config.sample_size * self.vae_scale_factor width = width or self.transformer.config.sample_size * self.vae_scale_factor if use_resolution_binning: if self.transformer.config.sample_size == 256: aspect_ratio_bin = ASPECT_RATIO_2048_BIN elif self.transformer.config.sample_size == 128: aspect_ratio_bin = ASPECT_RATIO_1024_BIN elif self.transformer.config.sample_size == 64: aspect_ratio_bin = ASPECT_RATIO_512_BIN elif self.transformer.config.sample_size == 32: aspect_ratio_bin = ASPECT_RATIO_256_BIN else: raise ValueError("Invalid sample size") orig_height, orig_width = height, width height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) self.check_inputs( prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask, ) self._pag_scale = pag_scale self._pag_adaptive_scale = pag_adaptive_scale # 2. Default height and width to transformer if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. 
Encode input prompt ( prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, ) = self.encode_prompt( prompt, do_classifier_free_guidance, negative_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, device=device, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, clean_caption=clean_caption, max_sequence_length=max_sequence_length, ) if self.do_perturbed_attention_guidance: prompt_embeds = self._prepare_perturbed_attention_guidance( prompt_embeds, negative_prompt_embeds, do_classifier_free_guidance ) prompt_attention_mask = self._prepare_perturbed_attention_guidance( prompt_attention_mask, negative_prompt_attention_mask, do_classifier_free_guidance ) elif do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 5. Prepare latents. latent_channels = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, latent_channels, height, width, prompt_embeds.dtype, device, generator, latents, ) if self.do_perturbed_attention_guidance: original_attn_proc = self.transformer.attn_processors self._set_pag_attn_processor( pag_applied_layers=self.pag_applied_layers, do_classifier_free_guidance=do_classifier_free_guidance, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.1 Prepare micro-conditions. added_cond_kwargs = {"resolution": None, "aspect_ratio": None} # 7. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) current_timestep = t if not torch.is_tensor(current_timestep): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" if isinstance(current_timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None].to(latent_model_input.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML current_timestep = current_timestep.expand(latent_model_input.shape[0]) # predict noise model_output noise_pred = self.transformer( latent_model_input, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, timestep=current_timestep, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_perturbed_attention_guidance: noise_pred = self._apply_perturbed_attention_guidance( noise_pred, do_classifier_free_guidance, guidance_scale, current_timestep ) elif do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # learned sigma if self.transformer.config.out_channels // 2 == latent_channels: noise_pred = noise_pred.chunk(2, dim=1)[0] else: 
noise_pred = noise_pred # compute previous image: x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) else: image = latents if not output_type == "latent": image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if self.do_perturbed_attention_guidance: self.transformer.set_attn_processor(original_attn_proc) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/pag/pipeline_pag_sd.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from .pag_utils import PAGMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import AutoPipelineForText2Image >>> pipe = AutoPipelineForText2Image.from_pretrained( ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True ... ) >>> pipe = pipe.to("cuda") >>> prompt = "a photo of an astronaut riding a horse on mars" >>> image = pipe(prompt, pag_scale=0.3).images[0] ``` """ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs, ): """ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. Args: scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. sigmas (`List[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. 
""" if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps class StableDiffusionPAGPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin, PAGMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
""" model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, pag_applied_layers: Union[str, List[str]] = "mid", ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) self.set_pag_applied_layers(pag_applied_layers) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt def encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the 
encoder layers. Then index into # the tuple to access the hidden states from the desired layer. prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode an image with `self.image_encoder` and build the matching unconditional encoding.

    Args:
        image:
            A `torch.Tensor`, or any other input, in which case it is first converted to pixel values by
            `self.feature_extractor`.
        device:
            Device the pixel values are moved to before they are encoded.
        num_images_per_prompt (`int`):
            Number of times every encoding is repeated (interleaved) along dimension 0.
        output_hidden_states (`bool`, *optional*):
            When truthy, the second-to-last entry of the encoder's `hidden_states` is returned instead of its
            `image_embeds` output.

    Returns:
        `tuple`: `(conditional, unconditional)`. With `output_hidden_states` the unconditional entry is the
        encoder's second-to-last hidden state for an all-zero image; otherwise it is a zero tensor shaped like
        the conditional embeddings.
    """
    encoder = self.image_encoder
    # Run the encoder in whatever precision its own weights use.
    encoder_dtype = next(encoder.parameters()).dtype

    pixel_values = image
    if not isinstance(pixel_values, torch.Tensor):
        pixel_values = self.feature_extractor(pixel_values, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = encoder(pixel_values).image_embeds
        cond_embeds = cond_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    def _second_to_last_states(pixels):
        states = encoder(pixels, output_hidden_states=True).hidden_states[-2]
        return states.repeat_interleave(num_images_per_prompt, dim=0)

    # The conditional pass runs first, then the all-zero image for the unconditional branch.
    cond_states = _second_to_last_states(pixel_values)
    uncond_states = _second_to_last_states(torch.zeros_like(pixel_values))
    return cond_states, uncond_states
) for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): output_hidden_state = not isinstance(image_proj_layer, ImageProjection) single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from 
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments that `self.scheduler.step` is able to accept.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only forwarded when the
    scheduler's `step` declares a parameter of that name.

    Args:
        generator: Value forwarded as `generator` when `step` accepts it.
        eta: Value forwarded as `eta` when `step` accepts it. eta (η) is only used with the DDIMScheduler and
            corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502); it should be between [0, 1].

    Returns:
        `dict`: Keyword arguments to splat into `self.scheduler.step(...)`.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    # Order matters only for readability of the resulting dict: `eta` first, then `generator`.
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Build sinusoidal embeddings of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance-scale values to embed; the result is used to enrich timestep embeddings.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, on the same device as `w`.

    Raises:
        ValueError: If `w` is not one-dimensional.
    """
    # Explicit check instead of `assert`, which disappears when Python runs with `-O`.
    if w.ndim != 1:
        raise ValueError(f"`w` has to be a 1-D tensor but has shape {tuple(w.shape)}.")
    w = w * 1000.0

    half_dim = embedding_dim // 2
    if half_dim > 1:
        log_step = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    else:
        # With a single frequency `half_dim - 1` is 0; dividing by it would turn the whole embedding into
        # NaN. Use a zero step so that the only frequency is exp(0) == 1.
        log_step = torch.tensor(0.0)

    # Allocate the frequency table next to `w` so that non-CPU inputs do not hit a device mismatch.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -log_step)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    # Internal invariant, not input validation.
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    pag_scale: float = 3.0,
    pag_adaptive_scale: float = 0.0,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
            using zero terminal SNR.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        pag_scale (`float`, *optional*, defaults to 3.0):
            The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
            guidance will not be used.
        pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
            The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
            used.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct.
    # `callback_steps` is passed as `None`: this pipeline only supports `callback_on_step_end`.
    self.check_inputs(
        prompt,
        height,
        width,
        None,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs,
    )

    # Stored on the instance because the `guidance_scale`, `clip_skip`, ... properties read them back.
    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False
    self._pag_scale = pag_scale
    self._pag_adaptive_scale = pag_adaptive_scale

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )

    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
        clip_skip=self.clip_skip,
    )

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_perturbed_attention_guidance:
        prompt_embeds = self._prepare_perturbed_attention_guidance(
            prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
        )
    elif self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

        # Re-batch every adapter's embeddings the same way the prompt embeddings were batched above.
        for i, image_embeds in enumerate(ip_adapter_image_embeds):
            negative_image_embeds = None
            if self.do_classifier_free_guidance:
                negative_image_embeds, image_embeds = image_embeds.chunk(2)

            if self.do_perturbed_attention_guidance:
                image_embeds = self._prepare_perturbed_attention_guidance(
                    image_embeds, negative_image_embeds, self.do_classifier_free_guidance
                )
            elif self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
            image_embeds = image_embeds.to(device)
            ip_adapter_image_embeds[i] = image_embeds

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 6.1 Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": ip_adapter_image_embeds}
        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
        else None
    )

    # 6.2 Optionally get Guidance Scale Embedding
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

    # Remember the UNet's processors before PAG swaps them, so they can be restored even if the run fails.
    original_attn_proc = None
    if self.do_perturbed_attention_guidance:
        original_attn_proc = self.unet.attn_processors

    try:
        if original_attn_proc is not None:
            self._set_pag_attn_processor(
                pag_applied_layers=self.pag_applied_layers,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
            )

        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents to match the guidance batch (one copy per block of `prompt_embeds`)
                latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_perturbed_attention_guidance:
                    if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                        # `rescale_noise_cfg` below needs the text-conditioned prediction, which the PAG
                        # helper does not hand back; without this the rescale step raised a NameError.
                        # NOTE(review): assumes the PAG+CFG batch is ordered [uncond, text, perturbed],
                        # as produced by `_prepare_perturbed_attention_guidance` - confirm against PAGMixin.
                        noise_pred_text = noise_pred.chunk(3)[1]
                    noise_pred = self._apply_perturbed_attention_guidance(
                        noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
                    )
                elif self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(
                        noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # Plain loop on purpose: `locals()` inside a comprehension would not see these names.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # advance the progress bar once per completed scheduler step (after warmup) and on the last step
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if output_type != "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False, generator=generator
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()
    finally:
        # Always put the original attention processors back; otherwise a failed run would leave the UNet
        # with PAG processors installed for every later call.
        if original_attn_proc is not None:
            self.unet.set_attn_processor(original_attn_proc)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import ( CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, ) from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin from ...models.attention_processor import PAGCFGJointAttnProcessor2_0, PAGJointAttnProcessor2_0 from ...models.autoencoders import AutoencoderKL from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput from .pag_utils import PAGMixin if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import AutoPipelineForText2Image >>> pipe = AutoPipelineForText2Image.from_pretrained( ... "stabilityai/stable-diffusion-3-medium-diffusers", ... torch_dtype=torch.float16, ... enable_pag=True, ... pag_applied_layers=["blocks.13"], ... 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or the scheduler's
    default spacing for `num_inference_steps`. Any extra `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. Only used when
            neither `timesteps` nor `sigmas` is passed.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as the device the timesteps should be moved to.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Cannot be combined with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Cannot be combined with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's `timesteps` after the call, and the number of inference
        steps (the length of the schedule when a custom schedule was supplied).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
            declare the custom argument that was passed.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _accepts(argument: str) -> bool:
        # Custom schedules are only supported by schedulers whose `set_timesteps` names the argument.
        return argument in inspect.signature(scheduler.set_timesteps).parameters

    # Default path: let the scheduler space `num_inference_steps` steps itself.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # With a custom schedule, the step count is whatever the scheduler ended up with.
    schedule = scheduler.timesteps
    return schedule, len(schedule)
text_encoder ([`CLIPTextModelWithProjection`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` as its dimension. text_encoder_2 ([`CLIPTextModelWithProjection`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. text_encoder_3 ([`T5EncoderModel`]): Frozen text-encoder. Stable Diffusion 3 uses [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] def __init__( self, transformer: SD3Transformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, text_encoder_2: CLIPTextModelWithProjection, tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, pag_applied_layers: Union[str, List[str]] = "blocks.1", # 1st transformer block ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, text_encoder_3=text_encoder_3, tokenizer=tokenizer, tokenizer_2=tokenizer_2, tokenizer_3=tokenizer_3, transformer=transformer, scheduler=scheduler, ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) self.default_sample_size = ( self.transformer.config.sample_size if hasattr(self, "transformer") and self.transformer is not None else 128 ) self.set_pag_applied_layers( pag_applied_layers, pag_attn_processors=(PAGCFGJointAttnProcessor2_0(), PAGJointAttnProcessor2_0()) ) # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, max_sequence_length: int = 256, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else 
def _get_clip_prompt_embeds(
    self,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    clip_skip: Optional[int] = None,
    clip_model_index: int = 0,
):
    """
    Encode `prompt` with one of the two CLIP text encoders and tile the result per generated image.

    NOTE: intentionally diverges from `StableDiffusion3Pipeline._get_clip_prompt_embeds`: the pooled embeddings
    are tiled per prompt so that they stay aligned with the sequence embeddings.

    Args:
        prompt (`str` or `List[str]`):
            Prompt or prompts to encode. A single string is treated as a batch of one.
        num_images_per_prompt (`int`, defaults to 1):
            How many times each prompt's embeddings are repeated.
        device (`torch.device`, *optional*):
            Target device; falls back to `self._execution_device`.
        clip_skip (`int`, *optional*):
            When `None`, hidden state `-2` of the encoder is used, otherwise hidden state `-(clip_skip + 2)`.
        clip_model_index (`int`, defaults to 0):
            `0` selects `tokenizer`/`text_encoder`, `1` selects `tokenizer_2`/`text_encoder_2`.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: sequence embeddings viewed as
        `(batch_size * num_images_per_prompt, seq_len, dim)` and pooled embeddings viewed as
        `(batch_size * num_images_per_prompt, -1)`. Row `b * num_images_per_prompt + k` of both tensors belongs
        to prompt `b`.
    """
    device = device or self._execution_device

    tokenizer = (self.tokenizer, self.tokenizer_2)[clip_model_index]
    text_encoder = (self.text_encoder, self.text_encoder_2)[clip_model_index]

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=self.tokenizer_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    # Tokenize again without truncation, only to report what was cut off.
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {self.tokenizer_max_length} tokens: {removed_text}"
        )

    encoder_output = text_encoder(text_input_ids.to(device), output_hidden_states=True)
    # First element of the encoder output is used as the pooled embedding.
    pooled_prompt_embeds = encoder_output[0]

    hidden_state_index = -2 if clip_skip is None else -(clip_skip + 2)
    prompt_embeds = encoder_output.hidden_states[hidden_state_index]

    # NOTE(review): the dtype of `self.text_encoder` is used for both encoders, as before - confirm this is
    # intended when `text_encoder_2` has a different dtype.
    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    # Tile the pooled embeddings per sample as well. Calling `repeat(1, n, 1)` directly on a 2-D
    # `(batch, dim)` tensor tiles along the batch axis and yields [p0, p1, p0, p1] instead of the
    # [p0, p0, p1, p1] order produced for `prompt_embeds` above, pairing pooled and sequence embeddings of
    # different prompts whenever batch_size > 1 and num_images_per_prompt > 1.
    pooled_prompt_embeds = pooled_prompt_embeds.reshape(batch_size, 1, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds, pooled_prompt_embeds
List[str]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, clip_skip: Optional[int] = None, max_sequence_length: int = 256, lora_scale: Optional[float] = None, ): r""" Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in all text-encoders prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is used in all text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 prompt_3 = prompt_3 or prompt prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=clip_skip, clip_model_index=0, ) prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( prompt=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=clip_skip, clip_model_index=1, ) clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) t5_prompt_embed = self._get_t5_prompt_embeds( prompt=prompt_3, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) clip_prompt_embeds = torch.nn.functional.pad( clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) ) prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt negative_prompt_3 = negative_prompt_3 or negative_prompt # normalize str to list 
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) negative_prompt_3 = ( batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 ) if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( negative_prompt, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=None, clip_model_index=0, ) negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( negative_prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=None, clip_model_index=1, ) negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) t5_negative_prompt_embed = self._get_t5_prompt_embeds( prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) negative_clip_prompt_embeds = torch.nn.functional.pad( negative_clip_prompt_embeds, (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), ) negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) negative_pooled_prompt_embeds = torch.cat( [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 ) if self.text_encoder is not None: if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: # 
def check_inputs(
    self,
    prompt,
    prompt_2,
    prompt_3,
    height,
    width,
    negative_prompt=None,
    negative_prompt_2=None,
    negative_prompt_3=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    max_sequence_length=None,
):
    """
    Validate the user-facing arguments of the pipeline call and raise `ValueError` on the first problem found.

    NOTE: intentionally diverges from `StableDiffusion3Pipeline.check_inputs`: the per-prompt checks are driven
    by `(name, value)` pairs so an error message always reports the argument that actually triggered it (the
    `prompt_3` conflict message used to print the value of `prompt_2`).

    Checks, in order:
        1. `height` and `width` are divisible by 8.
        2. Every entry of `callback_on_step_end_tensor_inputs` is listed in `self._callback_tensor_inputs`.
        3. None of `prompt`, `prompt_2`, `prompt_3` is passed together with `prompt_embeds`.
        4. At least one of `prompt` / `prompt_embeds` is given.
        5. Each given prompt is a `str` or a `list`.
        6. None of the negative prompts is passed together with `negative_prompt_embeds`.
        7. `prompt_embeds` and `negative_prompt_embeds`, when both given, have the same shape.
        8. Pre-computed (negative) embeddings come with their pooled counterpart.
        9. `max_sequence_length`, when given, does not exceed 512.

    Raises:
        `ValueError`: If any of the checks above fails.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None:
        unknown_inputs = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown_inputs:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown_inputs}"
            )

    positive_prompts = (("prompt", prompt), ("prompt_2", prompt_2), ("prompt_3", prompt_3))
    negative_prompts = (
        ("negative_prompt", negative_prompt),
        ("negative_prompt_2", negative_prompt_2),
        ("negative_prompt_3", negative_prompt_3),
    )

    # Raw prompts and pre-computed embeddings are mutually exclusive.
    if prompt_embeds is not None:
        for name, value in positive_prompts:
            if value is not None:
                raise ValueError(
                    f"Cannot forward both `{name}`: {value} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                    " only forward one of the two."
                )

    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )

    for name, value in positive_prompts:
        if value is not None and not isinstance(value, (str, list)):
            raise ValueError(f"`{name}` has to be of type `str` or `list` but is {type(value)}")

    if negative_prompt_embeds is not None:
        for name, value in negative_prompts:
            if value is not None:
                raise ValueError(
                    f"Cannot forward both `{name}`: {value} and `negative_prompt_embeds`:"
                    f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
                )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )

    if max_sequence_length is not None and max_sequence_length > 512:
        raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],  # only read, never mutated, in this method
    max_sequence_length: int = 256,
    pag_scale: float = 3.0,
    pag_adaptive_scale: float = 0.0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
            be used instead.
        prompt_3 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` will
            be used instead.
        height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 28):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 7.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used instead.
        negative_prompt_3 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
            `text_encoder_3`. If not defined, `negative_prompt` is used instead.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
            final latents are returned without being decoded by the VAE.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead
            of a plain tuple.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its `"scale"` entry, if present, is also used as the LoRA scale when encoding the prompts.
        clip_skip (`int`, *optional*):
            Forwarded to `encode_prompt` as `clip_skip`.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int` defaults to 256):
            Maximum sequence length to use with the `prompt`.
        pag_scale (`float`, *optional*, defaults to 3.0):
            The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
            guidance will not be used.
        pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
            The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
            used.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images.
    """

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    # State read back through the pipeline's properties during this call.
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False
    self._pag_scale = pag_scale
    self._pag_adaptive_scale = pag_adaptive_scale

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode the prompts
    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # Build the batched conditioning: the PAG helper assembles it when PAG is active, otherwise the negative
    # and positive embeddings are stacked for classifier-free guidance.
    if self.do_perturbed_attention_guidance:
        prompt_embeds = self._prepare_perturbed_attention_guidance(
            prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
        )
        pooled_prompt_embeds = self._prepare_perturbed_attention_guidance(
            pooled_prompt_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance
        )
    elif self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # Remember the current attention processors before swapping in the PAG ones, so they can be put back.
    original_attn_proc = None
    if self.do_perturbed_attention_guidance:
        original_attn_proc = self.transformer.attn_processors
        self._set_pag_attn_processor(
            pag_applied_layers=self.pag_applied_layers,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
        )

    try:
        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both
                latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_perturbed_attention_guidance:
                    noise_pred = self._apply_perturbed_attention_guidance(
                        noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
                    )
                elif self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        # looked up by name, so the local variable names above must not be renamed
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                    )

                # advance the progress bar on the last step and on scheduler-order boundaries after warmup
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        # 7. Decode the latents unless raw latents were requested
        if output_type == "latent":
            image = latents
        else:
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()
    finally:
        # Restore the original processors even when denoising or decoding raised; otherwise the PAG
        # processors would stay installed and be captured as the "original" ones by the next call.
        if original_attn_proc is not None:
            self.transformer.set_attn_processor(original_attn_proc)

    if not return_dict:
        return (image,)

    return StableDiffusion3PipelineOutput(images=image)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unets.unet_motion_model import MotionAdapter from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..animatediff.pipeline_output import AnimateDiffPipelineOutput from ..free_init_utils import FreeInitMixin from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pag_utils import PAGMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import AnimateDiffPAGPipeline, MotionAdapter, DDIMScheduler >>> from diffusers.utils import export_to_gif >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE" >>> motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-2" >>> motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id) >>> scheduler = DDIMScheduler.from_pretrained( ... 
model_id, subfolder="scheduler", beta_schedule="linear", steps_offset=1, clip_sample=False ... ) >>> pipe = AnimateDiffPAGPipeline.from_pretrained( ... model_id, ... motion_adapter=motion_adapter, ... scheduler=scheduler, ... pag_applied_layers=["mid"], ... torch_dtype=torch.float16, ... ).to("cuda") >>> video = pipe( ... prompt="car, futuristic cityscape with neon lights, street, no human", ... negative_prompt="low quality, bad quality", ... num_inference_steps=25, ... guidance_scale=6.0, ... pag_scale=3.0, ... generator=torch.Generator().manual_seed(42), ... ).frames[0] >>> export_to_gif(video, "animatediff_pag.gif") ``` """ class AnimateDiffPAGPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, FreeInitMixin, AnimateDiffFreeNoiseMixin, PAGMixin, ): r""" Pipeline for text-to-video generation using [AnimateDiff](https://huggingface.co/docs/diffusers/en/api/pipelines/animatediff) and [Perturbed Attention Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag). This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: Union[UNet2DConditionModel, UNetMotionModel],
    motion_adapter: MotionAdapter,
    scheduler: KarrasDiffusionSchedulers,
    feature_extractor: CLIPImageProcessor = None,
    image_encoder: CLIPVisionModelWithProjection = None,
    pag_applied_layers: Union[str, List[str]] = "mid_block.*attn1",  # ["mid"], ["down_blocks.1"]
):
    r"""
    Register the pipeline components and configure perturbed attention guidance.

    A plain `UNet2DConditionModel` is converted to a `UNetMotionModel` with `UNetMotionModel.from_unet2d` using
    the supplied `motion_adapter`; a `UNetMotionModel` is registered unchanged.

    Args:
        vae, text_encoder, tokenizer, unet, motion_adapter, scheduler, feature_extractor, image_encoder:
            Components passed to `register_modules`.
        pag_applied_layers (`str` or `List[str]`, defaults to `"mid_block.*attn1"`):
            Forwarded to `set_pag_applied_layers`.
    """
    super().__init__()
    # Only a 2D UNet needs wrapping; the conversion happens before registration so the registered
    # `unet` is always the motion model.
    if isinstance(unet, UNet2DConditionModel):
        unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        motion_adapter=motion_adapter,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    # Spatial ratio between pixel space and latent space, derived from the number of VAE blocks.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)

    self.set_pag_applied_layers(pag_applied_layers)
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when `do_classifier_free_guidance` is `False`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`. `prompt_embeds` is reshaped to
        `(batch * num_images_per_prompt, seq_len, dim)`. `negative_prompt_embeds` is processed the same way only
        when `do_classifier_free_guidance` is `True`; otherwise it is returned exactly as it was passed in.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are of different types.
        ValueError: If a list `negative_prompt` does not match the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # Batch size comes from the prompt when given, otherwise from the precomputed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize a second time without truncation purely to detect, and warn about, dropped text.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Target dtype preference: text encoder, then UNet, then whatever the embeddings already use.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    r"""
    Encode an image with `self.image_encoder` and build the matching unconditional encoding.

    Args:
        image: A `torch.Tensor`, or any input accepted by `self.feature_extractor`.
        device: Device the pixel values are moved to before encoding.
        num_images_per_prompt (`int`): Repeat count applied along dim 0 with `repeat_interleave`.
        output_hidden_states (`bool`, *optional*): When truthy, return the encoder's second-to-last hidden
            states instead of `image_embeds`.

    Returns:
        `tuple`: `(conditional, unconditional)`. Without hidden states the unconditional tensor is all zeros
        shaped like the conditional one; with hidden states it is the encoding of an all-zero image.
    """
    encoder = self.image_encoder
    encoder_dtype = next(encoder.parameters()).dtype

    # Anything that is not already a tensor goes through the feature extractor first.
    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = encoder(image).image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    def penultimate_states(pixels):
        states = encoder(pixels, output_hidden_states=True).hidden_states[-2]
        return states.repeat_interleave(num_images_per_prompt, dim=0)

    cond_states = penultimate_states(image)
    uncond_states = penultimate_states(torch.zeros_like(image))
    return cond_states, uncond_states
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    r"""
    Build one IP-Adapter embedding tensor per adapter, tiled for `num_images_per_prompt`.

    Args:
        ip_adapter_image: Image, or list of images (one per IP-Adapter); used only when
            `ip_adapter_image_embeds` is `None`.
        ip_adapter_image_embeds: Optional list of precomputed embeddings. Under classifier-free guidance each
            tensor is split in two along dim 0 as `(negative, positive)`.
        device: Device each resulting tensor is moved to.
        num_images_per_prompt (`int`): How many times each embedding is concatenated along dim 0.
        do_classifier_free_guidance (`bool`): When `True`, the tiled negative embedding is concatenated in front
            of the tiled positive one.

    Returns:
        `list`: One tensor per adapter.

    Raises:
        ValueError: If the number of images differs from the number of image projection layers.
    """
    positives = []
    negatives = []

    if ip_adapter_image_embeds is not None:
        for embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_part, embeds = embeds.chunk(2)
                negatives.append(negative_part)
            positives.append(embeds)
    else:
        images = ip_adapter_image if isinstance(ip_adapter_image, list) else [ip_adapter_image]
        projection_layers = self.unet.encoder_hid_proj.image_projection_layers

        if len(images) != len(projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(images)} images and {len(projection_layers)} IP Adapters."
            )

        for image, layer in zip(images, projection_layers):
            # Only plain `ImageProjection` layers consume pooled embeds; the rest use hidden states.
            wants_hidden_states = not isinstance(layer, ImageProjection)
            cond, uncond = self.encode_image(image, device, 1, wants_hidden_states)
            positives.append(cond[None, :])
            if do_classifier_free_guidance:
                negatives.append(uncond[None, :])

    prepared = []
    for index, cond in enumerate(positives):
        tiled = torch.cat([cond] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negatives[index]] * num_images_per_prompt, dim=0)
            tiled = torch.cat([tiled_negative, tiled], dim=0)
        prepared.append(tiled.to(device=device))

    return prepared
def decode_latents(self, latents, decode_chunk_size: int = 16):
    r"""
    Decode video latents frame by frame with the VAE, `decode_chunk_size` frames at a time.

    Args:
        latents: 5D tensor unpacked as `(batch, channels, frames, height, width)`.
        decode_chunk_size (`int`, defaults to 16): Number of frames passed to `self.vae.decode` per call.

    Returns:
        `torch.Tensor`: float32 video laid out as `(batch, channels, frames, height, width)`.
    """
    unscaled = 1 / self.vae.config.scaling_factor * latents

    batch_size, channels, num_frames, height, width = unscaled.shape
    # Fold the frame axis into the batch axis so the image VAE sees individual frames.
    frames = unscaled.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

    decoded_chunks = [
        self.vae.decode(frames[start : start + decode_chunk_size]).sample
        for start in range(0, frames.shape[0], decode_chunk_size)
    ]
    decoded = torch.cat(decoded_chunks)

    # Restore (batch, frames, ...) and move channels back in front of the frame axis.
    restored_shape = (batch_size, num_frames, -1) + decoded.shape[2:]
    video = decoded[None, :].reshape(restored_shape).permute(0, 2, 1, 3, 4)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return video.float()
def prepare_extra_step_kwargs(self, generator, eta):
    r"""
    Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are forwarded only when the
    scheduler declares them. `eta` (η) corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502)
    and should be between [0, 1].

    Args:
        generator: Random generator(s) to forward as `generator`.
        eta (`float`): Value to forward as `eta`.

    Returns:
        `dict`: Contains `"eta"` and/or `"generator"`, each only if `step` has a parameter of that name.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_parameters}
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    r"""
    Validate the arguments of the pipeline call; returns `None` when everything is consistent.

    Args:
        prompt (`str` or `List[str]`, *optional*): Text prompt; mutually exclusive with `prompt_embeds`.
        height (`int`), width (`int`): Output size in pixels; both must be divisible by 8.
        negative_prompt: Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds, negative_prompt_embeds: Precomputed embeddings; when both are given their shapes must
            be equal.
        ip_adapter_image: Mutually exclusive with `ip_adapter_image_embeds`.
        ip_adapter_image_embeds: Non-empty `list` whose elements all have `ndim` 3 or 4.
        callback_on_step_end_tensor_inputs: Names that must all appear in `self._callback_tensor_inputs`.

    Raises:
        ValueError: On any violated constraint listed above.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        # An empty list used to escape as an IndexError from `[0]`; report it as a validation error instead.
        if not ip_adapter_image_embeds:
            raise ValueError("`ip_adapter_image_embeds` has to be a non-empty list of 3D or 4D tensors but is empty")
        # Every adapter's embedding is consumed later, so validate all of them rather than only the first.
        for single_image_embeds in ip_adapter_image_embeds:
            if single_image_embeds.ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {single_image_embeds.ndim}D"
                )
def prepare_latents(
    self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
    r"""
    Produce the initial video latents, scaled by the scheduler's `init_noise_sigma`.

    When `self.free_noise_enabled` is set, the latents are first generated by
    `self._prepare_latents_free_noise` (Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)).
    Latents that exist at that point are only moved to `device`; otherwise fresh noise of shape
    `(batch_size, num_channels_latents, num_frames, height // vae_scale_factor, width // vae_scale_factor)`
    is drawn with `randn_tensor`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    if self.free_noise_enabled:
        latents = self._prepare_latents_free_noise(
            batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
        )

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor
    noise_shape = (batch_size, num_channels_latents, num_frames, latent_height, latent_width)

    if latents is not None:
        initial = latents.to(device)
    else:
        initial = randn_tensor(noise_shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active, i.e. the stored guidance scale is strictly greater than 1."""
    return self._guidance_scale > 1

@property
def cross_attention_kwargs(self):
    """The `cross_attention_kwargs` stored on the pipeline by `__call__` (may be `None`)."""
    return self._cross_attention_kwargs

@property
def num_timesteps(self):
    """The value stored in `_num_timesteps`, which `__call__` sets to `len(timesteps)`."""
    return self._num_timesteps
num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality videos at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. Latents should be of shape `(batch_size, num_channel, num_frames, height, width)`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. pag_scale (`float`, *optional*, defaults to 3.0): The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention guidance will not be used. pag_adaptive_scale (`float`, *optional*, defaults to 0.0): The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is used. Examples: Returns: [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, negative_prompt, prompt_embeds, negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._pag_scale = pag_scale self._pag_adaptive_scale = pag_adaptive_scale # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. 
Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_videos_per_prompt, self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_perturbed_attention_guidance: prompt_embeds = self._prepare_perturbed_attention_guidance( prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance ) elif self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt, self.do_classifier_free_guidance, ) for i, image_embeds in enumerate(ip_adapter_image_embeds): negative_image_embeds = None if self.do_classifier_free_guidance: negative_image_embeds, image_embeds = image_embeds.chunk(2) if self.do_perturbed_attention_guidance: image_embeds = self._prepare_perturbed_attention_guidance( image_embeds, negative_image_embeds, self.do_classifier_free_guidance ) elif self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) image_embeds = image_embeds.to(device) ip_adapter_image_embeds[i] = image_embeds # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. 
Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, num_frames, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": ip_adapter_image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None ) if self.do_perturbed_attention_guidance: original_attn_proc = self.unet.attn_processors self._set_pag_attn_processor( pag_applied_layers=self.pag_applied_layers, do_classifier_free_guidance=self.do_classifier_free_guidance, ) num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 for free_init_iter in range(num_free_init_iters): if self.free_init_enabled: latents, timesteps = self._apply_free_init( latents, free_init_iter, num_inference_steps, device, latents.dtype, generator ) self._num_timesteps = len(timesteps) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order # 8. 
Denoising loop with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, ).sample # perform guidance if self.do_perturbed_attention_guidance: noise_pred = self._apply_perturbed_attention_guidance( noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() # 9. Post processing if output_type == "latent": video = latents else: video_tensor = self.decode_latents(latents, decode_chunk_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. 
Offload all models self.maybe_free_model_hooks() if self.do_perturbed_attention_guidance: self.unet.set_attn_processor(original_attn_proc) if not return_dict: return (video,) return AnimateDiffPipelineOutput(frames=video) ================================================ FILE: src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, is_invisible_watermark_available, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin 
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from .pag_utils import PAGMixin if is_invisible_watermark_available(): from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import AutoPipelineForText2Image >>> pipe = AutoPipelineForText2Image.from_pretrained( ... "stabilityai/stable-diffusion-xl-base-1.0", ... torch_dtype=torch.float16, ... enable_pag=True, ... ) >>> pipe = pipe.to("cuda") >>> prompt = "a photo of an astronaut riding a horse on mars" >>> image = pipe(prompt, pag_scale=0.3).images[0] ``` """ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and hand back the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or (when neither is
    given) a plain `num_inference_steps` count. Extra keyword arguments are forwarded to
    `scheduler.set_timesteps` unchanged.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler whose `set_timesteps` is called and whose `timesteps` attribute is read afterwards.
        num_inference_steps (`int`, *optional*):
            Number of steps passed positionally to `set_timesteps` when no custom schedule is supplied.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as the `device` keyword.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule. Mutually exclusive with `timesteps`.

    Returns:
        `tuple`: `(scheduler.timesteps, num_inference_steps)`. For a custom schedule the second element is
        `len(scheduler.timesteps)`; otherwise it is the `num_inference_steps` argument as passed in.

    Raises:
        `ValueError`: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` signature
        lacks the keyword required by the requested custom schedule.
    """
    wants_timesteps = timesteps is not None
    wants_sigmas = sigmas is not None

    if wants_timesteps and wants_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own spacing from the step count.
    if not (wants_timesteps or wants_sigmas):
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom path: the scheduler must explicitly accept the keyword we are about to pass.
    accepted_names = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if wants_timesteps:
        if "timesteps" not in accepted_names:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_names:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion XL uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([` CLIPTextModelWithProjection`]): Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    image_encoder: CLIPVisionModelWithProjection = None,
    feature_extractor: CLIPImageProcessor = None,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
    pag_applied_layers: Union[str, List[str]] = "mid",  # ["mid"],["down.block_1"],["up.block_0.attentions_0"]
):
    # Registers every sub-model on the pipeline, stores the one config flag, derives the
    # image-processing helpers from the registered VAE/UNet, and finally records which
    # layers perturbed-attention guidance is applied to.
    super().__init__()

    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "text_encoder_2": text_encoder_2,
        "tokenizer": tokenizer,
        "tokenizer_2": tokenizer_2,
        "unet": unet,
        "scheduler": scheduler,
        "image_encoder": image_encoder,
        "feature_extractor": feature_extractor,
    }
    self.register_modules(**components)
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

    # Scale factor is 2 ** (number of VAE block widths - 1); read from the registered VAE's config.
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    self.default_sample_size = self.unet.config.sample_size

    # An explicit True/False wins; None falls back to whether the watermark package is installed.
    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()
    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None

    self.set_pag_applied_layers(pag_applied_layers)
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
""" device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was 
truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = 
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode `image` with `self.image_encoder` and return a `(conditional, unconditional)` pair.

    Non-tensor inputs are first converted with `self.feature_extractor`. The image is moved to `device` and
    cast to the dtype of the image encoder's first parameter before encoding.

    Args:
        image: A `torch.Tensor`, or anything `self.feature_extractor` accepts.
        device: Target device for the image tensor.
        num_images_per_prompt (`int`): How many times each batch entry is repeated along dim 0.
        output_hidden_states (`bool`, *optional*): When truthy, the second-to-last hidden state of the encoder
            is returned instead of `image_embeds`.

    Returns:
        `tuple`: With `output_hidden_states`, the second-to-last hidden states for the image and for an
        all-zero image of the same shape. Otherwise the encoder's `image_embeds` and a zero tensor of the same
        shape. Both elements are repeated `num_images_per_prompt` times along dim 0.
    """
    encoder_dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        # Pooled-embedding path: the unconditional counterpart is simply zeros.
        pooled = self.image_encoder(image).image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return pooled, torch.zeros_like(pooled)

    # Hidden-state path: the unconditional counterpart comes from encoding an all-zero image.
    cond_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
    cond_states = cond_states.repeat_interleave(num_images_per_prompt, dim=0)

    blank_image = torch.zeros_like(image)
    uncond_states = self.image_encoder(blank_image, output_hidden_states=True).hidden_states[-2]
    uncond_states = uncond_states.repeat_interleave(num_images_per_prompt, dim=0)

    return cond_states, uncond_states
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build one embedding tensor per IP-Adapter, tiled for `num_images_per_prompt`.

    When `ip_adapter_image_embeds` is `None`, each entry of `ip_adapter_image` is encoded with
    `self.encode_image`; otherwise the supplied embeddings are used. Under classifier-free guidance each
    supplied tensor is split in two along dim 0 (negative half first), and the returned tensors are the
    concatenation `[negative, positive]` along dim 0.

    Args:
        ip_adapter_image: A single image or a list with one image per IP-Adapter projection layer.
        ip_adapter_image_embeds: Optional list of precomputed embedding tensors.
        device: Device the resulting tensors are moved to.
        num_images_per_prompt (`int`): Number of copies concatenated along dim 0.
        do_classifier_free_guidance (`bool`): Whether negative embeddings are tracked and prepended.

    Returns:
        `list`: One tensor per adapter.

    Raises:
        `ValueError`: If images are given and their count differs from the number of image projection layers
        on `self.unet.encoder_hid_proj`.
    """
    positives = []
    negatives = []

    if ip_adapter_image_embeds is not None:
        # Precomputed embeddings: under CFG each tensor packs [negative, positive] along dim 0.
        for packed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_half, packed = packed.chunk(2)
                negatives.append(negative_half)
            positives.append(packed)
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        projection_layers = self.unet.encoder_hid_proj.image_projection_layers
        for adapter_image, projection in zip(ip_adapter_image, projection_layers):
            # Anything other than a plain ImageProjection is fed hidden states rather than pooled embeds.
            needs_hidden_states = not isinstance(projection, ImageProjection)
            cond, uncond = self.encode_image(adapter_image, device, 1, needs_hidden_states)

            positives.append(cond[None, :])
            if do_classifier_free_guidance:
                negatives.append(uncond[None, :])

    prepared = []
    for index, cond in enumerate(positives):
        tiled = torch.cat([cond] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negatives[index]] * num_images_per_prompt, dim=0)
            tiled = torch.cat([tiled_negative, tiled], dim=0)
        prepared.append(tiled.to(device=device))

    return prepared
) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) elif negative_prompt_2 is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the starting latents for denoising, scaled by `self.scheduler.init_noise_sigma`.

    Args:
        batch_size (`int`): Effective batch size (first dimension of the latent shape).
        num_channels_latents (`int`): Second dimension of the latent shape.
        height, width: Image size; each is converted with `int()` and floor-divided by
            `self.vae_scale_factor` to obtain the latent size.
        dtype: dtype used when fresh noise is sampled.
        device: Device for sampled noise, or the device supplied latents are moved to.
        generator: Generator (or list of generators) forwarded to `randn_tensor`.
        latents (*optional*): Pre-made latents. When given they are moved to `device` instead of sampling.

    Returns:
        The latents multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        `ValueError`: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(width) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if isinstance(generator, list):
        if len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

    if latents is not None:
        initial = latents.to(device)
    else:
        initial = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
def upcast_vae(self):
    """
    Cast the VAE to float32, keeping selected sub-modules in the previous dtype when possible.

    The whole VAE is first moved to `torch.float32`. If the processor of the first attention layer in the decoder's
    mid block is one of `AttnProcessor2_0`, `XFormersAttnProcessor` or `FusedAttnProcessor2_0`, then
    `post_quant_conv`, `decoder.conv_in` and `decoder.mid_block` are cast back to the dtype the VAE had on entry.
    """
    previous_dtype = self.vae.dtype
    self.vae.to(dtype=torch.float32)

    first_mid_block_processor = self.vae.decoder.mid_block.attentions[0].processor
    low_precision_capable = (
        AttnProcessor2_0,
        XFormersAttnProcessor,
        FusedAttnProcessor2_0,
    )
    if not isinstance(first_mid_block_processor, low_precision_capable):
        return

    # With xformers / torch 2.0 attention the attention block does not need to stay in float32,
    # which can save a lot of memory.
    self.vae.post_quant_conv.to(previous_dtype)
    self.vae.decoder.conv_in.to(previous_dtype)
    self.vae.decoder.mid_block.to(previous_dtype)
""" assert len(w.shape) == 1 w = w * 1000.0 half_dim = embedding_dim // 2 emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) emb = w.to(dtype)[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb @property def guidance_scale(self): return self._guidance_scale @property def guidance_rescale(self): return self._guidance_rescale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None @property def cross_attention_kwargs(self): return self._cross_attention_kwargs @property def denoising_end(self): return self._denoising_end @property def num_timesteps(self): return self._num_timesteps @property def interrupt(self): return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = 
None, pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, original_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Optional[Tuple[int, int]] = None, negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], pag_scale: float = 3.0, pag_adaptive_scale: float = 0.0, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. 
Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. 
of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a specific image resolution. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): To negatively condition the generation process based on a specific crop coordinates. 
Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a target image resolution. It should be as same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. pag_scale (`float`, *optional*, defaults to 3.0): The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention guidance will not be used. pag_adaptive_scale (`float`, *optional*, defaults to 0.0): The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is used. Examples: Returns: [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, prompt_2, height, width, None, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._guidance_rescale = guidance_rescale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._interrupt = False self._pag_scale = pag_scale self._pag_adaptive_scale = pag_adaptive_scale # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. Encode input prompt lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=lora_scale, clip_skip=self.clip_skip, ) # 4. 
Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids if self.do_perturbed_attention_guidance: prompt_embeds = self._prepare_perturbed_attention_guidance( prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance ) add_text_embeds = self._prepare_perturbed_attention_guidance( add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance ) add_time_ids = self._prepare_perturbed_attention_guidance( add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance ) elif self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, 
add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) for i, image_embeds in enumerate(ip_adapter_image_embeds): negative_image_embeds = None if self.do_classifier_free_guidance: negative_image_embeds, image_embeds = image_embeds.chunk(2) if self.do_perturbed_attention_guidance: image_embeds = self._prepare_perturbed_attention_guidance( image_embeds, negative_image_embeds, self.do_classifier_free_guidance ) elif self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) image_embeds = image_embeds.to(device) ip_adapter_image_embeds[i] = image_embeds # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) # 8.1 Apply denoising_end if ( self.denoising_end is not None and isinstance(self.denoising_end, float) and self.denoising_end > 0 and self.denoising_end < 1 ): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # 9. 
Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) if self.do_perturbed_attention_guidance: original_attn_proc = self.unet.attn_processors self._set_pag_attn_processor( pag_applied_layers=self.pag_applied_layers, do_classifier_free_guidance=self.do_classifier_free_guidance, ) self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_perturbed_attention_guidance: noise_pred = self._apply_perturbed_attention_guidance( noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: xm.mark_step() if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) elif latents.dtype != self.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None if has_latents_mean and has_latents_std: latents_mean = ( torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents_std = ( torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean else: latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents if not output_type == "latent": # apply watermark if available if self.watermark is not None: image = self.watermark.apply_watermark(image) image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if self.do_perturbed_attention_guidance: self.unet.set_attn_processor(original_attn_proc) if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import PIL.Image import torch from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, is_invisible_watermark_available, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from .pag_utils import PAGMixin if is_invisible_watermark_available(): from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend `noise_cfg` with a copy of itself rescaled to the standard deviation of `noise_pred_text`.

    Follows Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf). Standard deviations are taken over every dimension except the
    first. `guidance_rescale` is the weight of the rescaled tensor in the blend and `1 - guidance_rescale` the
    weight of the unmodified `noise_cfg`.
    """

    def _std_over_non_leading_dims(tensor):
        reduce_dims = list(range(1, tensor.ndim))
        return tensor.std(dim=reduce_dims, keepdim=True)

    text_std = _std_over_non_leading_dims(noise_pred_text)
    cfg_std = _std_over_non_leading_dims(noise_cfg)

    # Matching the text prediction's spread counters the overexposure caused by guidance.
    rescaled = noise_cfg * (text_std / cfg_std)

    # Mixing with the untouched guidance result avoids "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through `set_timesteps` and return the schedule it ends up with.

    Exactly one of three modes is used: custom `timesteps`, custom `sigmas`, or a plain `num_inference_steps`.
    Extra keyword arguments are forwarded to `scheduler.set_timesteps` in every mode.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler whose `set_timesteps` is called and whose `timesteps` attribute is read back.
        num_inference_steps (`int`):
            Number of steps to request when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Passed to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps; requires `set_timesteps` to accept a `timesteps` parameter.
        sigmas (`List[float]`, *optional*):
            Custom sigmas; requires `set_timesteps` to accept a `sigmas` parameter.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after the call, and the number of inference steps (the
        length of that schedule in the custom modes, `num_inference_steps` unchanged otherwise).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler derive its own spacing from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    accepted_parameters = inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if "timesteps" not in accepted_parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # In both custom modes the step count is whatever the scheduler produced.
    schedule = scheduler.timesteps
    return schedule, len(schedule)
Stable Diffusion XL uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([`CLIPTextModelWithProjection`]): Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. requires_aesthetics_score (`bool`, *optional*, defaults to `False`): Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config of `stabilityai/stable-diffusion-xl-refiner-1.0`. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`): Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1.0`. add_watermarker (`bool`, *optional*): Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to watermark output images. If not defined, it will default to True if the package is installed, otherwise no watermarker will be used.
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    image_encoder: CLIPVisionModelWithProjection = None,
    feature_extractor: CLIPImageProcessor = None,
    requires_aesthetics_score: bool = False,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
    pag_applied_layers: Union[str, List[str]] = "mid",  # ["mid"], ["down.block_1", "up.block_0.attentions_0"]
):
    """
    Register the pipeline components and config flags, then set up image processing, watermarking and PAG layers.
    """
    super().__init__()

    # Registration order is kept as-is; `scheduler` is deliberately last.
    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "text_encoder_2": text_encoder_2,
        "tokenizer": tokenizer,
        "tokenizer_2": tokenizer_2,
        "unet": unet,
        "image_encoder": image_encoder,
        "feature_extractor": feature_extractor,
        "scheduler": scheduler,
    }
    self.register_modules(**components)

    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
    self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)

    # The scale factor is derived from the number of entries in the VAE's `block_out_channels`.
    vae_level_count = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (vae_level_count - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # When the caller does not decide, watermark only if the invisible-watermark package is available.
    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()
    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None

    self.set_pag_applied_layers(pag_applied_layers)
None, num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, negative_prompt: Optional[str] = None, negative_prompt_2: Optional[str] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = 
[prompt_2] if isinstance(prompt_2, str) else prompt_2 # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. 
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Scheduler `step` signatures differ, so `eta` and `generator` are only included when the current
    scheduler's `step` declares a parameter with that name. `eta` corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        A dict containing `"eta"` and/or `"generator"` (in that order), or an empty dict.
    """
    step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
    candidates = {"eta": eta, "generator": generator}
    return {name: value for name, value in candidates.items() if name in step_params}
) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) elif negative_prompt_2 is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
    """
    Select the tail of `self.scheduler.timesteps` that should actually be run.

    Without `denoising_start`, the first `num_inference_steps - int(num_inference_steps * strength)` steps
    (multiplied by the scheduler order) are dropped. With `denoising_start`, `strength` is irrelevant: only
    timesteps strictly below the cutoff derived from `denoising_start` are kept.

    Returns:
        A tuple `(timesteps, num_inference_steps)` with the selected schedule and the number of steps it represents.
    """
    scheduler = self.scheduler

    if denoising_start is None:
        kept_steps = min(int(num_inference_steps * strength), num_inference_steps)
        first_index = max(num_inference_steps - kept_steps, 0)
    else:
        first_index = 0

    schedule = scheduler.timesteps[first_index * scheduler.order :]

    if denoising_start is None:
        return schedule, num_inference_steps - first_index

    # Translate the fractional start into a discrete training-timestep cutoff.
    train_steps = scheduler.config.num_train_timesteps
    cutoff = int(round(train_steps - (denoising_start * train_steps)))
    remaining = (schedule < cutoff).sum().item()

    if scheduler.order == 2 and remaining % 2 == 0:
        # A 2nd order scheduler duplicates every timestep except the highest one. An even count would cut
        # the schedule in the middle of a denoising step (between the 1st and 2nd derivative), so one more
        # entry is taken to make the process end after the 2nd derivative step.
        remaining += 1

    # Timesteps are ordered from high to low, so the wanted part is sliced from the end.
    return schedule[-remaining:], remaining
self.text_encoder_2.to("cpu") torch.cuda.empty_cache() image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt if image.shape[1] == 4: init_latents = image else: # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.config.force_upcast: image = image.float() self.vae.to(dtype=torch.float32) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) elif isinstance(generator, list): if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " ) init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = retrieve_latents(self.vae.encode(image), generator=generator) if self.vae.config.force_upcast: self.vae.to(dtype) init_latents = init_latents.to(dtype) if latents_mean is not None and latents_std is not None: latents_mean = latents_mean.to(device=device, dtype=dtype) latents_std = latents_std.to(device=device, dtype=dtype) init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std else: init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode `image` with `self.image_encoder` and build a matching unconditional counterpart.

    Inputs that are not tensors are first converted with `self.feature_extractor`. Each result is repeated
    `num_images_per_prompt` times along dim 0.

    Returns:
        A `(conditional, unconditional)` tuple. With `output_hidden_states` truthy, both are the
        second-to-last hidden state of the encoder, the unconditional one computed from an all-zero image.
        Otherwise the conditional value is `image_embeds` and the unconditional value is zeros of the same shape.
    """
    encoder = self.image_encoder
    encoder_dtype = next(encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = encoder(image).image_embeds
        cond_embeds = cond_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    def _penultimate_hidden(pixels):
        # Second-to-last hidden state, repeated once per requested image.
        hidden = encoder(pixels, output_hidden_states=True).hidden_states[-2]
        return hidden.repeat_interleave(num_images_per_prompt, dim=0)

    cond_hidden = _penultimate_hidden(image)
    uncond_hidden = _penultimate_hidden(torch.zeros_like(image))
    return cond_hidden, uncond_hidden
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
def _get_add_time_ids(
    self,
    original_size,
    crops_coords_top_left,
    target_size,
    aesthetic_score,
    negative_aesthetic_score,
    negative_original_size,
    negative_crops_coords_top_left,
    negative_target_size,
    dtype,
    text_encoder_projection_dim=None,
):
    """
    Build the positive and negative additional time-id tensors passed to the UNet.

    The size/crop arguments are concatenated with `+`, so they must be tuples. When
    `self.config.requires_aesthetics_score` is set, the ids are `original_size + crops + (score,)`;
    otherwise they are `original_size + crops + target_size`. The negative ids are always built from the
    `negative_*` arguments, including `negative_crops_coords_top_left`.

    Args:
        original_size, crops_coords_top_left, target_size: Tuples used for the positive ids.
        aesthetic_score, negative_aesthetic_score: Scalars appended when aesthetic scores are required.
        negative_original_size, negative_crops_coords_top_left, negative_target_size:
            Tuples used for the negative ids.
        dtype: dtype of the returned tensors.
        text_encoder_projection_dim: Added to the time-embedding width when checking against the UNet.

    Returns:
        `(add_time_ids, add_neg_time_ids)`, each a tensor with a leading dimension of 1.

    Raises:
        ValueError: If the resulting embedding width does not match
            `self.unet.add_embedding.linear_1.in_features`.
    """
    if self.config.requires_aesthetics_score:
        add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
        add_neg_time_ids = list(
            negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
        )
    else:
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        # Use the negative crop coordinates here as well; previously the positive ones were reused,
        # which silently ignored `negative_crops_coords_top_left` in this branch.
        add_neg_time_ids = list(negative_original_size + negative_crops_coords_top_left + negative_target_size)

    passed_add_embed_dim = (
        self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
    )
    expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

    # A difference of exactly one time-embedding slot points at a wrong `requires_aesthetics_score` setting.
    if (
        expected_add_embed_dim > passed_add_embed_dim
        and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
    ):
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
        )
    elif (
        expected_add_embed_dim < passed_add_embed_dim
        and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
    ):
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
        )
    elif expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)

    return add_time_ids, add_neg_time_ids
""" assert len(w.shape) == 1 w = w * 1000.0 half_dim = embedding_dim // 2 emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) emb = w.to(dtype)[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb @property def guidance_scale(self): return self._guidance_scale @property def guidance_rescale(self): return self._guidance_rescale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None @property def cross_attention_kwargs(self): return self._cross_attention_kwargs @property def denoising_end(self): return self._denoising_end @property def denoising_start(self): return self._denoising_start @property def num_timesteps(self): return self._num_timesteps @property def interrupt(self): return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, strength: float = 0.3, num_inference_steps: int = 50, timesteps: List[int] = None, sigmas: List[float] = None, denoising_start: Optional[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    image: PipelineImageInput = None,
    strength: float = 0.3,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    denoising_start: Optional[float] = None,
    denoising_end: Optional[float] = None,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    original_size: Tuple[int, int] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    target_size: Tuple[int, int] = None,
    negative_original_size: Optional[Tuple[int, int]] = None,
    negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
    negative_target_size: Optional[Tuple[int, int]] = None,
    aesthetic_score: float = 6.0,
    negative_aesthetic_score: float = 2.5,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    pag_scale: float = 3.0,
    pag_adaptive_scale: float = 0.0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders.
        image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
            The image(s) to modify with the pipeline.
        strength (`float`, *optional*, defaults to 0.3):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that when
            `denoising_start` is specified, the value of `strength` will be ignored.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        denoising_start (`float`, *optional*):
            When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
            bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
            it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
            strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
            is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
            Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
        denoising_end (`float`, *optional*):
            When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
            completed before it is intentionally prematurely terminated. As a result, the returned sample will
            still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
            denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
            final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
            forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image
            Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            not greater than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        ip_adapter_image (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_rescale` is defined as `φ` in equation 16. of
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
            Guidance rescale factor should fix overexposure when using zero terminal SNR.
        original_size (`Tuple[int]`, *optional*, defaults to `(height, width)`):
            If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
            `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
            explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
            `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
            `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        target_size (`Tuple[int]`, *optional*, defaults to `(height, width)`):
            For most cases, `target_size` should be set to the desired height and width of the generated image. If
            not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
            section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        negative_original_size (`Tuple[int]`, *optional*, defaults to `original_size`):
            To negatively condition the generation process based on a specific image resolution. Part of SDXL's
            micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
            micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_target_size (`Tuple[int]`, *optional*, defaults to `target_size`):
            To negatively condition the generation process based on a target image resolution. It should be as same
            as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        aesthetic_score (`float`, *optional*, defaults to 6.0):
            Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
            Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
            Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
            simulate an aesthetic score of the generated image by influencing the negative text condition.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during inference, with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        pag_scale (`float`, *optional*, defaults to 3.0):
            The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
            guidance will not be used.
        pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
            The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
            used.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images.
    """

    # Callback objects carry their own list of tensors to expose; it overrides the argument.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        strength,
        num_inference_steps,
        None,
        negative_prompt,
        negative_prompt_2,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs,
    )

    # Stash the call arguments that are exposed through the read-only properties.
    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._denoising_end = denoising_end
    self._denoising_start = denoising_start
    self._interrupt = False
    self._pag_scale = pag_scale
    self._pag_adaptive_scale = pag_adaptive_scale

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )

    # 4. Preprocess image
    image = self.image_processor.preprocess(image)

    # 5. Prepare timesteps
    def denoising_value_valid(dnv):
        # Only floats strictly inside (0, 1) are treated as a denoising fraction.
        return isinstance(dnv, float) and 0 < dnv < 1

    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )
    timesteps, num_inference_steps = self.get_timesteps(
        num_inference_steps,
        strength,
        device,
        denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
    )
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

    # Noise is only added when no `denoising_start` was given at all.
    add_noise = self.denoising_start is None

    # 6. Prepare latent variables
    if latents is None:
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
            add_noise,
        )

    # 7. Prepare extra step kwargs.
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    height, width = latents.shape[-2:]
    height = height * self.vae_scale_factor
    width = width * self.vae_scale_factor

    original_size = original_size or (height, width)
    target_size = target_size or (height, width)

    # 8. Prepare added time ids & embeddings
    if negative_original_size is None:
        negative_original_size = original_size
    if negative_target_size is None:
        negative_target_size = target_size

    add_text_embeds = pooled_prompt_embeds
    if self.text_encoder_2 is None:
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
    else:
        text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

    add_time_ids, add_neg_time_ids = self._get_add_time_ids(
        original_size,
        crops_coords_top_left,
        target_size,
        aesthetic_score,
        negative_aesthetic_score,
        negative_original_size,
        negative_crops_coords_top_left,
        negative_target_size,
        dtype=prompt_embeds.dtype,
        text_encoder_projection_dim=text_encoder_projection_dim,
    )
    add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
    add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)

    if self.do_perturbed_attention_guidance:
        prompt_embeds = self._prepare_perturbed_attention_guidance(
            prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
        )
        add_text_embeds = self._prepare_perturbed_attention_guidance(
            add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance
        )
        add_time_ids = self._prepare_perturbed_attention_guidance(
            add_time_ids, add_neg_time_ids, self.do_classifier_free_guidance
        )
    elif self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    add_time_ids = add_time_ids.to(device)

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

        for i, image_embeds in enumerate(ip_adapter_image_embeds):
            negative_image_embeds = None
            if self.do_classifier_free_guidance:
                negative_image_embeds, image_embeds = image_embeds.chunk(2)

            if self.do_perturbed_attention_guidance:
                image_embeds = self._prepare_perturbed_attention_guidance(
                    image_embeds, negative_image_embeds, self.do_classifier_free_guidance
                )
            elif self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)

            image_embeds = image_embeds.to(device)
            ip_adapter_image_embeds[i] = image_embeds

    # 9. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    # 9.1 Apply denoising_end
    if (
        self.denoising_end is not None
        and self.denoising_start is not None
        and denoising_value_valid(self.denoising_end)
        and denoising_value_valid(self.denoising_start)
        and self.denoising_start >= self.denoising_end
    ):
        raise ValueError(
            f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
            + f" {self.denoising_end} when using type float."
        )
    elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
        discrete_timestep_cutoff = int(
            round(
                self.scheduler.config.num_train_timesteps
                - (self.denoising_end * self.scheduler.config.num_train_timesteps)
            )
        )
        # Keep only the leading timesteps that are at or above the cutoff.
        num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
        timesteps = timesteps[:num_inference_steps]

    # 9.2 Optionally get Guidance Scale Embedding
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    if self.do_perturbed_attention_guidance:
        # Remember the current processors so they can be put back after generation.
        original_attn_proc = self.unet.attn_processors
        self._set_pag_attn_processor(
            pag_applied_layers=self.pag_applied_layers,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
        )

    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents to match the batch size of the (possibly guidance-expanded) prompt embeddings
            latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
            if ip_adapter_image_embeds is not None:
                added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_perturbed_attention_guidance:
                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # The rescale step below needs the text-conditioned prediction, which previously was only
                    # bound in the plain-CFG branch (NameError with PAG + CFG + guidance_rescale > 0).
                    # NOTE(review): assumes the batch is laid out as [uncond, text, perturbed], i.e. the order
                    # produced by `_prepare_perturbed_attention_guidance` — confirm against PAGMixin.
                    noise_pred_text = noise_pred.chunk(3)[1]
                noise_pred = self._apply_perturbed_attention_guidance(
                    noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
                )
            elif self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            # call the callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    # requested tensors are looked up by name among the local variables of this call
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                negative_pooled_prompt_embeds = callback_outputs.pop(
                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                )
                add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
                add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)

            # advance the progress bar on the last step and after each full scheduler-order step past warmup
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    if output_type != "latent":
        # make sure the VAE is in float32 mode, as it overflows in float16
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        if needs_upcasting:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
        elif latents.dtype != self.vae.dtype:
            if torch.backends.mps.is_available():
                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                self.vae = self.vae.to(latents.dtype)

        # unscale/denormalize the latents
        # denormalize with the mean and std if available and not None
        has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
        has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
        if has_latents_mean and has_latents_std:
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            )
            latents_std = (
                torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            )
            latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
        else:
            latents = latents / self.vae.config.scaling_factor

        image = self.vae.decode(latents, return_dict=False)[0]

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
    else:
        image = latents

    # apply watermark if available
    if self.watermark is not None:
        image = self.watermark.apply_watermark(image)

    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    # Restore the attention processors that were swapped out for PAG.
    if self.do_perturbed_attention_guidance:
        self.unet.set_attn_processor(original_attn_proc)

    if not return_dict:
        return (image,)

    return StableDiffusionXLPipelineOutput(images=image)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import PIL.Image import torch from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, is_invisible_watermark_available, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from .pag_utils import PAGMixin if is_invisible_watermark_available(): from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend `noise_cfg` with a copy of itself rescaled to the standard deviation of `noise_pred_text`.

    Standard deviations are taken over every dimension except the first, so each batch item is rescaled on its own.
    `guidance_rescale` is the blend weight of the rescaled tensor; `1 - guidance_rescale` is the weight of the
    unmodified `noise_cfg`. Based on Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).
    """

    def _per_item_std(tensor):
        # reduce over all non-batch dimensions, keeping them for broadcasting
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    text_std = _per_item_std(noise_pred_text)
    cfg_std = _per_item_std(noise_cfg)

    # match the spread of the guided prediction to the text-conditioned one (fixes overexposure)
    rescaled = noise_cfg * (text_std / cfg_std)

    # blend with the unscaled guidance result to avoid "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Pull latents out of an encoder output object.

    If `encoder_output` has a `latent_dist` attribute, `sample_mode="sample"` returns
    `latent_dist.sample(generator)` and `sample_mode="argmax"` returns `latent_dist.mode()`. Otherwise, or for any
    other `sample_mode`, the `latents` attribute is returned when present.

    Raises:
        AttributeError: if none of the above applies.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()

    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. sigmas (`List[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." 
) scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps class StableDiffusionXLPAGInpaintPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, IPAdapterMixin, PAGMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion XL uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([` CLIPTextModelWithProjection`]): Second frozen text-encoder. 
Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config of `stabilityai/stable-diffusion-xl-refiner-1-0`. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1-0`. add_watermarker (`bool`, *optional*): Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to watermark output images. If not defined, it will default to True if the package is installed, otherwise no watermarker will be used. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    image_encoder: CLIPVisionModelWithProjection = None,
    feature_extractor: CLIPImageProcessor = None,
    requires_aesthetics_score: bool = False,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
    pag_applied_layers: Union[str, List[str]] = "mid",  # ["mid"], ["down.block_1", "up.block_0.attentions_0"]
):
    """
    Register the pipeline components and build the image/mask pre-processors.

    The sub-models are registered through `register_modules`, the two boolean flags are stored in the
    pipeline config, and the VAE scale factor is derived from the number of VAE block output channels.
    When `add_watermarker` is `None`, watermarking is enabled only if
    `is_invisible_watermark_available()` returns a truthy value. Finally the PAG layers are configured
    from `pag_applied_layers`.
    """
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        image_encoder=image_encoder,
        feature_extractor=feature_extractor,
        scheduler=scheduler,
    )
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
    self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)

    # One factor of two per VAE block transition.
    scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.vae_scale_factor = scale_factor
    self.image_processor = VaeImageProcessor(vae_scale_factor=scale_factor)
    # Masks are binarized single-channel maps and must not be normalized like images.
    self.mask_processor = VaeImageProcessor(
        vae_scale_factor=scale_factor,
        do_normalize=False,
        do_binarize=True,
        do_convert_grayscale=True,
    )

    # Resolve the tri-state flag: an explicit value wins, otherwise follow package availability.
    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()
    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None

    self.set_pag_applied_layers(pag_applied_layers)
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode an image with `self.image_encoder` and build the matching unconditional encoding.

    Non-tensor inputs are first converted with `self.feature_extractor`. The pixels are moved to
    `device` and cast to the dtype of the image encoder's first parameter.

    Args:
        image: A `torch.Tensor` of pixel values, or any input accepted by `self.feature_extractor`.
        device: Device the pixel tensor is moved to before encoding.
        num_images_per_prompt (`int`): Each encoding is repeated this many times along dim 0.
        output_hidden_states (`bool`, *optional*): When truthy, return the penultimate hidden states
            instead of the pooled `image_embeds`.

    Returns:
        A `(conditional, unconditional)` pair. With hidden states, the unconditional part is the encoder
        output for an all-zero image; otherwise it is a zero tensor shaped like the conditional embeds.
    """
    encoder = self.image_encoder
    encoder_dtype = next(encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    pixels = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = encoder(pixels).image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    def penultimate_states(batch):
        # Second-to-last hidden state, duplicated once per requested image.
        hidden = encoder(batch, output_hidden_states=True).hidden_states[-2]
        return hidden.repeat_interleave(num_images_per_prompt, dim=0)

    cond_states = penultimate_states(pixels)
    uncond_states = penultimate_states(torch.zeros_like(pixels))
    return cond_states, uncond_states
) for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): output_hidden_state = not isinstance(image_proj_layer, ImageProjection) single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( self, prompt: str, prompt_2: Optional[str] = None, device: Optional[torch.device] = None, num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, negative_prompt: Optional[str] = None, negative_prompt_2: Optional[str] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text 
encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. 
lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = 
text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" 
{type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only included when the
    scheduler's `step` method declares a parameter of that name.

    Args:
        generator: Value forwarded as `generator` if supported.
        eta: Value forwarded as `eta` if supported.

    Returns:
        `dict`: Subset of `{"eta": eta, "generator": generator}` accepted by `self.scheduler.step`.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
) elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) elif negative_prompt_2 is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." ) if not isinstance(mask_image, PIL.Image.Image): raise ValueError( f"The mask image should be a PIL image when inpainting mask crop, but is of type" f" {type(mask_image)}." 
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
    image=None,
    timestep=None,
    is_strength_max=True,
    add_noise=True,
    return_noise=False,
    return_image_latents=False,
):
    """
    Build the initial latents for the denoising loop, optionally with the noise and image latents.

    Args:
        batch_size (`int`): Effective batch size of the returned tensors.
        num_channels_latents (`int`): Channel count of the latent tensor.
        height (`int`), width (`int`): Pixel size; divided by `self.vae_scale_factor` for the latent shape.
        dtype, device: Target dtype and device for created or converted tensors.
        generator: A generator or a list of generators. A list must have length `batch_size`.
        latents (`torch.Tensor`, *optional*): Pre-made noise. With `add_noise=True` it is scaled by
            `self.scheduler.init_noise_sigma`.
        image (`torch.Tensor`, *optional*): Source image. A tensor with 4 channels in dim 1 is used as
            latents directly; anything else is encoded with `self._encode_vae_image` when image latents
            are needed.
        timestep (*optional*): Timestep passed to `self.scheduler.add_noise` when `is_strength_max` is False.
        is_strength_max (`bool`): If True, start from pure noise; otherwise from image latents plus noise.
        add_noise (`bool`): If False, the returned latents are the image latents themselves.
        return_noise (`bool`): Append the noise tensor to the result.
        return_image_latents (`bool`): Append the image latents to the result.

    Returns:
        `tuple`: `(latents,)` followed by `noise` and/or `image_latents` when requested, in that order.

    Raises:
        ValueError: If the generator list length differs from `batch_size`, if `image` or `timestep` is
            missing while `is_strength_max` is False, or if image latents are required
            (`return_image_latents=True` or `add_noise=False`) but no `image` was given.
    """
    shape = (
        batch_size,
        num_channels_latents,
        int(height) // self.vae_scale_factor,
        int(width) // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if (image is None or timestep is None) and not is_strength_max:
        raise ValueError(
            "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
            "However, either the image or the noise timestep has not been provided."
        )

    # Every code path below that reads `image_latents` is covered by this flag. The original code
    # omitted the `not add_noise` case and hit an UnboundLocalError there.
    needs_image_latents = return_image_latents or not add_noise or (latents is None and not is_strength_max)
    if needs_image_latents and image is None:
        raise ValueError(
            "`image` must be provided when image latents are required (`return_image_latents=True` or"
            " `add_noise=False`)."
        )

    image_latents = None
    # `image` defaults to None and is legitimately absent for pure-noise starts, so guard the access.
    if image is not None:
        if image.shape[1] == 4:
            # Already in latent space: only move/cast.
            image_latents = image.to(device=device, dtype=dtype)
        elif needs_image_latents:
            image_latents = self._encode_vae_image(image=image.to(device=device, dtype=dtype), generator=generator)
        if image_latents is not None:
            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)

    if latents is None and add_noise:
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        # if strength is 1. then initialise the latents to noise, else initial to image + noise
        latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
        # if pure noise then scale the initial latents by the Scheduler's init sigma
        latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
    elif add_noise:
        noise = latents.to(device)
        latents = noise * self.scheduler.init_noise_sigma
    else:
        # Noise is still drawn so that `return_noise` works and the generator advances as before.
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = image_latents.to(device)

    outputs = (latents,)
    if return_noise:
        outputs += (noise,)
    if return_image_latents:
        outputs += (image_latents,)
    return outputs
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
def prepare_mask_latents(
    self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
    """
    Bring the mask (and, if given, the masked image) to latent resolution and batch size.

    Args:
        mask (`torch.Tensor`): Mask resized to `(height, width) // self.vae_scale_factor`.
        masked_image (`torch.Tensor`, *optional*): A tensor with 4 channels in dim 1 is used as latents
            as-is; anything else is encoded with `self._encode_vae_image`.
        batch_size (`int`): Target batch size; smaller inputs are tiled up to it.
        height (`int`), width (`int`): Pixel size of the generation.
        dtype, device: Target dtype and device of the returned tensors.
        generator: Forwarded to `self._encode_vae_image`.
        do_classifier_free_guidance (`bool`): If True, both outputs are duplicated along dim 0.

    Returns:
        `Tuple[torch.Tensor, Optional[torch.Tensor]]`: The mask and the masked-image latents (`None` when
        `masked_image` is `None`).

    Raises:
        ValueError: If `batch_size` is not a multiple of the number of masks or images passed.
    """
    # resize the mask to latents shape as we concatenate the mask to the latents
    # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
    # and half precision
    latent_hw = (height // self.vae_scale_factor, width // self.vae_scale_factor)
    mask = torch.nn.functional.interpolate(mask, size=latent_hw).to(device=device, dtype=dtype)

    # Tile the mask up to the batch size, using an mps friendly method.
    mask_count = mask.shape[0]
    if mask_count < batch_size:
        if batch_size % mask_count != 0:
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {batch_size}, but {mask_count} masks were passed. Make sure the number"
                " of masks that you pass is divisible by the total requested batch size."
            )
        mask = mask.repeat(batch_size // mask_count, 1, 1, 1)
    if do_classifier_free_guidance:
        mask = torch.cat([mask, mask])

    if masked_image is None:
        return mask, None

    if masked_image.shape[1] == 4:
        masked_image_latents = masked_image
    else:
        pixels = masked_image.to(device=device, dtype=dtype)
        masked_image_latents = self._encode_vae_image(pixels, generator=generator)

    latent_count = masked_image_latents.shape[0]
    if latent_count < batch_size:
        if batch_size % latent_count != 0:
            raise ValueError(
                "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                f" to a total batch size of {batch_size}, but {latent_count} images were passed."
                " Make sure the number of images that you pass is divisible by the total requested batch size."
            )
        masked_image_latents = masked_image_latents.repeat(batch_size // latent_count, 1, 1, 1)
    if do_classifier_free_guidance:
        masked_image_latents = torch.cat([masked_image_latents, masked_image_latents])

    # aligning device to prevent device errors when concating it with the latent model input
    masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
    return mask, masked_image_latents
# Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
def _get_add_time_ids(
    self,
    original_size,
    crops_coords_top_left,
    target_size,
    aesthetic_score,
    negative_aesthetic_score,
    negative_original_size,
    negative_crops_coords_top_left,
    negative_target_size,
    dtype,
    text_encoder_projection_dim=None,
):
    """
    Build the positive and negative additional time-id tensors and validate their embedding size.

    When `self.config.requires_aesthetics_score` is set, each id vector is
    `original_size + crops + (aesthetic_score,)`; otherwise it is `original_size + crops + target_size`.
    The negative vector is built from the `negative_*` arguments in both cases.

    Args:
        original_size, crops_coords_top_left, target_size (`tuple`): Positive size conditioning values.
        aesthetic_score, negative_aesthetic_score: Scores used only when aesthetics conditioning is on.
        negative_original_size, negative_crops_coords_top_left, negative_target_size (`tuple`):
            Negative size conditioning values.
        dtype: dtype of the returned tensors.
        text_encoder_projection_dim (`int`): Added to the time-id embedding width when checking it against
            `self.unet.add_embedding.linear_1.in_features`.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: `add_time_ids` and `add_neg_time_ids`, each with a leading
        dimension of 1.

    Raises:
        ValueError: If the resulting embedding width does not match what the UNet expects.
    """
    if self.config.requires_aesthetics_score:
        add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
        add_neg_time_ids = list(
            negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
        )
    else:
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        # Use the negative crop coordinates here as well; the previous code reused the positive
        # `crops_coords_top_left`, silently discarding `negative_crops_coords_top_left`.
        add_neg_time_ids = list(negative_original_size + negative_crops_coords_top_left + negative_target_size)

    time_embed_dim = self.unet.config.addition_time_embed_dim
    passed_add_embed_dim = time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
    expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

    # A difference of exactly one time embedding means the aesthetics flag is set the wrong way round.
    if (
        expected_add_embed_dim > passed_add_embed_dim
        and (expected_add_embed_dim - passed_add_embed_dim) == time_embed_dim
    ):
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
        )
    elif (
        expected_add_embed_dim < passed_add_embed_dim
        and (passed_add_embed_dim - expected_add_embed_dim) == time_embed_dim
    ):
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
        )
    elif expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)

    return add_time_ids, add_neg_time_ids
Args: w (`torch.Tensor`): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type of the generated embeddings. Returns: `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """ assert len(w.shape) == 1 w = w * 1000.0 half_dim = embedding_dim // 2 emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) emb = w.to(dtype)[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb @property def guidance_scale(self): return self._guidance_scale @property def guidance_rescale(self): return self._guidance_rescale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
@property
def do_classifier_free_guidance(self):
    """`True` when `_guidance_scale > 1` and the UNet config has no `time_cond_proj_dim`."""
    return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None


@property
def cross_attention_kwargs(self):
    """Cross-attention kwargs stored on the pipeline as `_cross_attention_kwargs`."""
    return self._cross_attention_kwargs


@property
def denoising_end(self):
    """`denoising_end` value stored on the pipeline as `_denoising_end`."""
    return self._denoising_end


@property
def denoising_start(self):
    """`denoising_start` value stored on the pipeline as `_denoising_start`."""
    return self._denoising_start


@property
def num_timesteps(self):
    """Number of timesteps stored on the pipeline as `_num_timesteps`."""
    return self._num_timesteps


@property
def interrupt(self):
    """Interrupt flag stored on the pipeline as `_interrupt`; the denoising loop skips steps while it is set."""
    return self._interrupt


@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    image: PipelineImageInput = None,
    mask_image: PipelineImageInput = None,
    masked_image_latents: torch.Tensor = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    padding_mask_crop: Optional[int] = None,
    strength: float = 0.9999,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    denoising_start: Optional[float] = None,
    denoising_end: Optional[float] = None,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    original_size: Tuple[int, int] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    target_size: Tuple[int, int] = None,
    negative_original_size: Optional[Tuple[int, int]] = None,
    negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
    negative_target_size: Optional[Tuple[int, int]] = None,
    aesthetic_score: float = 6.0,
    negative_aesthetic_score: float = 2.5,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    pag_scale: float = 3.0,
    pag_adaptive_scale: float = 0.0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
            be masked out with `mask_image` and repainted according to `prompt`.
        mask_image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
            repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
            to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
            instead of 3, so the expected shape would be `(B, H, W, 1)`.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image. This is set to 1024 by default for the best results.
            Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image. This is set to 1024 by default for the best results.
            Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        padding_mask_crop (`int`, *optional*, defaults to `None`):
            The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
            image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
            with the same aspect ratio of the image and contains all masked area, and then expand that area based
            on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
            resizing to the original image size for inpainting. This is useful when the masked area is small while
            the image is large and contain information irrelevant for inpainting, such as background.
        strength (`float`, *optional*, defaults to 0.9999):
            Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
            between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
            `strength`. The number of denoising steps depends on the amount of noise initially added. When
            `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
            iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
            portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
            integer, the value of `strength` will be ignored.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        denoising_start (`float`, *optional*):
            When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
            bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
            it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
            strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
            is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
            Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
        denoising_end (`float`, *optional*):
            When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
            completed before it is intentionally prematurely terminated. As a result, the returned sample will
            still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
            denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
            final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
            forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
            Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
            `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
            explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
            `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
            `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            For most cases, `target_size` should be set to the desired height and width of the generated image. If
            not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
            section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            To negatively condition the generation process based on a specific image resolution. Part of SDXL's
            micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
            micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            To negatively condition the generation process based on a target image resolution. It should be as same
            as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        aesthetic_score (`float`, *optional*, defaults to 6.0):
            Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
            Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
            Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
            simulate an aesthetic score of the generated image by influencing the negative text condition.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        pag_scale (`float`, *optional*, defaults to 3.0):
            The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention
            guidance will not be used.
        pag_adaptive_scale (`float`, *optional*, defaults to 0.0):
            The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is
            used.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images. When
        `output_type` is `"latent"`, a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] holding
        the latents is returned regardless of `return_dict`.
    """

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs
    self.check_inputs(
        prompt,
        prompt_2,
        image,
        mask_image,
        height,
        width,
        strength,
        None,
        output_type,
        negative_prompt,
        negative_prompt_2,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs,
        padding_mask_crop,
    )

    # Store the per-call settings that the read-only properties of the pipeline expose.
    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._denoising_end = denoising_end
    self._denoising_start = denoising_start
    self._interrupt = False
    self._pag_scale = pag_scale
    self._pag_adaptive_scale = pag_adaptive_scale

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )

    # 4. set timesteps
    def denoising_value_valid(dnv):
        # Only floats strictly inside (0, 1) are treated as a usable denoising fraction.
        return isinstance(dnv, float) and 0 < dnv < 1

    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )
    timesteps, num_inference_steps = self.get_timesteps(
        num_inference_steps,
        strength,
        device,
        denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
    )
    # check that number of inference steps is not < 1 - as this doesn't make sense
    if num_inference_steps < 1:
        raise ValueError(
            f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
            f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
        )
    # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
    # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
    is_strength_max = strength == 1.0

    # 5. Preprocess mask and image
    if padding_mask_crop is not None:
        crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
        resize_mode = "fill"
    else:
        crops_coords = None
        resize_mode = "default"

    original_image = image
    init_image = self.image_processor.preprocess(
        image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
    )
    init_image = init_image.to(dtype=torch.float32)

    mask = self.mask_processor.preprocess(
        mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
    )

    if masked_image_latents is not None:
        masked_image = masked_image_latents
    elif init_image.shape[1] == 4:
        # if images are in latent space, we can't mask it
        masked_image = None
    else:
        masked_image = init_image * (mask < 0.5)

    # 6. Prepare latent variables
    num_channels_latents = self.vae.config.latent_channels
    num_channels_unet = self.unet.config.in_channels
    # A 4-channel UNet gets no mask channels as input, so the image latents are needed in the loop below
    # to blend the unmasked region back in after every step.
    return_image_latents = num_channels_unet == 4

    add_noise = self.denoising_start is None
    latents_outputs = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
        image=init_image,
        timestep=latent_timestep,
        is_strength_max=is_strength_max,
        add_noise=add_noise,
        return_noise=True,
        return_image_latents=return_image_latents,
    )

    if return_image_latents:
        latents, noise, image_latents = latents_outputs
    else:
        latents, noise = latents_outputs

    # 7. Prepare mask latent variables
    mask, masked_image_latents = self.prepare_mask_latents(
        mask,
        masked_image,
        batch_size * num_images_per_prompt,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        self.do_classifier_free_guidance,
    )

    # 8. Check that sizes of mask, masked image and latents match
    if num_channels_unet == 9:
        # default case for runwayml/stable-diffusion-inpainting
        num_channels_mask = mask.shape[1]
        num_channels_masked_image = masked_image_latents.shape[1]
        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
                " `pipeline.unet` or your `mask_image` or `image` input."
            )
    elif num_channels_unet != 4:
        raise ValueError(
            f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
        )

    # 8.1 Prepare extra step kwargs.
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 9. Re-derive height/width from the latents and default the size conditioning.
    # TODO: Logic should ideally just be moved out of the pipeline
    height, width = latents.shape[-2:]
    height = height * self.vae_scale_factor
    width = width * self.vae_scale_factor

    original_size = original_size or (height, width)
    target_size = target_size or (height, width)

    # 10. Prepare added time ids & embeddings
    if negative_original_size is None:
        negative_original_size = original_size
    if negative_target_size is None:
        negative_target_size = target_size

    add_text_embeds = pooled_prompt_embeds
    if self.text_encoder_2 is None:
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
    else:
        text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

    add_time_ids, add_neg_time_ids = self._get_add_time_ids(
        original_size,
        crops_coords_top_left,
        target_size,
        aesthetic_score,
        negative_aesthetic_score,
        negative_original_size,
        negative_crops_coords_top_left,
        negative_target_size,
        dtype=prompt_embeds.dtype,
        text_encoder_projection_dim=text_encoder_projection_dim,
    )
    add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
    add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)

    if self.do_perturbed_attention_guidance:
        prompt_embeds = self._prepare_perturbed_attention_guidance(
            prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance
        )
        add_text_embeds = self._prepare_perturbed_attention_guidance(
            add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance
        )
        add_time_ids = self._prepare_perturbed_attention_guidance(
            add_time_ids, add_neg_time_ids, self.do_classifier_free_guidance
        )
    elif self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    add_time_ids = add_time_ids.to(device)

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

        for i, image_embeds in enumerate(ip_adapter_image_embeds):
            negative_image_embeds = None
            if self.do_classifier_free_guidance:
                negative_image_embeds, image_embeds = image_embeds.chunk(2)

            if self.do_perturbed_attention_guidance:
                image_embeds = self._prepare_perturbed_attention_guidance(
                    image_embeds, negative_image_embeds, self.do_classifier_free_guidance
                )
            elif self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)

            image_embeds = image_embeds.to(device)
            ip_adapter_image_embeds[i] = image_embeds

    # 11. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    if (
        self.denoising_end is not None
        and self.denoising_start is not None
        and denoising_value_valid(self.denoising_end)
        and denoising_value_valid(self.denoising_start)
        and self.denoising_start >= self.denoising_end
    ):
        raise ValueError(
            f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
            + f" {self.denoising_end} when using type float."
        )
    elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
        discrete_timestep_cutoff = int(
            round(
                self.scheduler.config.num_train_timesteps
                - (self.denoising_end * self.scheduler.config.num_train_timesteps)
            )
        )
        # Keep only the leading timesteps that are at or above the cutoff.
        num_inference_steps = sum(1 for ts in timesteps if ts >= discrete_timestep_cutoff)
        timesteps = timesteps[:num_inference_steps]

    # 11.1 Optionally get Guidance Scale Embedding
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    if self.do_perturbed_attention_guidance:
        # Remember the current processors so they can be put back once generation is finished.
        original_attn_proc = self.unet.attn_processors
        self._set_pag_attn_processor(
            pag_applied_layers=self.pag_applied_layers,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
        )

    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # tile the latents so their batch matches the (possibly guidance-expanded) prompt embeddings
            latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            if num_channels_unet == 9:
                # concat latents, mask, masked_image_latents in the channel dimension
                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

            # predict the noise residual
            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
            if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds

            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_perturbed_attention_guidance:
                noise_pred = self._apply_perturbed_attention_guidance(
                    noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
                )
            elif self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # NOTE(review): `noise_pred_text` is only assigned in the plain classifier-free-guidance branch
            # above. With perturbed attention guidance *and* classifier-free guidance *and*
            # `guidance_rescale > 0` the call below would fail on the unbound name. A fix needs the text
            # prediction from `_apply_perturbed_attention_guidance`, which is defined outside this method.
            if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            if num_channels_unet == 4:
                # Blend the (re-noised) original image latents back into the region the mask preserves.
                init_latents_proper = image_latents
                if self.do_classifier_free_guidance:
                    init_mask, _ = mask.chunk(2)
                else:
                    init_mask = mask

                if i < len(timesteps) - 1:
                    noise_timestep = timesteps[i + 1]
                    init_latents_proper = self.scheduler.add_noise(
                        init_latents_proper, noise, torch.tensor([noise_timestep])
                    )

                latents = (1 - init_mask) * init_latents_proper + init_mask * latents

            if callback_on_step_end is not None:
                # Tensors are looked up by their local variable name, so the locals of this function must
                # keep their names and this must stay a plain loop (a comprehension has its own scope on
                # older Python versions).
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                negative_pooled_prompt_embeds = callback_outputs.pop(
                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                )
                add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
                add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
                mask = callback_outputs.pop("mask", mask)
                masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)

            # advance the progress bar on the last step and after each completed scheduler step past warmup
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    if output_type == "latent":
        # The latent path returns early, so it has to do the same cleanup as the decoded path below;
        # otherwise the PAG attention processors would stay installed on the UNet after the call.
        self.maybe_free_model_hooks()

        if self.do_perturbed_attention_guidance:
            self.unet.set_attn_processor(original_attn_proc)

        return StableDiffusionXLPipelineOutput(images=latents)

    # make sure the VAE is in float32 mode, as it overflows in float16
    needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

    if needs_upcasting:
        self.upcast_vae()
        latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
    elif latents.dtype != self.vae.dtype:
        if torch.backends.mps.is_available():
            # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
            self.vae = self.vae.to(latents.dtype)

    # unscale/denormalize the latents
    # denormalize with the mean and std if available and not None
    has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
    has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
    if has_latents_mean and has_latents_std:
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
        )
        latents_std = (
            torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
        )
        latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
    else:
        latents = latents / self.vae.config.scaling_factor

    image = self.vae.decode(latents, return_dict=False)[0]

    # cast back to fp16 if needed
    if needs_upcasting:
        self.vae.to(dtype=torch.float16)

    # apply watermark if available
    if self.watermark is not None:
        image = self.watermark.apply_watermark(image)

    image = self.image_processor.postprocess(image, output_type=output_type)

    if padding_mask_crop is not None:
        image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]

    # Offload all models
    self.maybe_free_model_hooks()

    if self.do_perturbed_attention_guidance:
        self.unet.set_attn_processor(original_attn_proc)

    if not return_dict:
        return (image,)

    return StableDiffusionXLPipelineOutput(images=image)
class PaintByExampleImageEncoder(CLIPPreTrainedModel):
    """CLIP vision encoder followed by a transformer mapper, a layer norm and a linear projection.

    The module also owns a learned `uncond_vector` parameter that `forward` can return alongside the
    projected embedding.
    """

    def __init__(self, config, proj_size=None):
        """Build the sub-modules.

        Args:
            config: CLIP configuration; `hidden_size` is read here and, when `proj_size` is falsy,
                `projection_dim` (falling back to 768 when the attribute is absent).
            proj_size: Output width of the final projection. A falsy value selects the config default.
        """
        super().__init__(config)

        if not proj_size:
            proj_size = getattr(config, "projection_dim", 768)
        self.proj_size = proj_size

        # Sub-modules are created in a fixed order so parameter initialisation consumes the RNG identically.
        self.model = CLIPVisionModel(config)
        self.mapper = PaintByExampleMapper(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size)
        self.proj_out = nn.Linear(config.hidden_size, self.proj_size)

        # uncondition for scaling
        self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))

    def forward(self, pixel_values, return_uncond_vector=False):
        """Encode `pixel_values` and optionally return the learned unconditional vector as well.

        Returns:
            The projected embedding, or the tuple `(embedding, self.uncond_vector)` when
            `return_uncond_vector` is truthy.
        """
        pooled = self.model(pixel_values=pixel_values).pooler_output
        # `[:, None]` inserts a new axis at position 1 before the mapper runs.
        mapped = self.mapper(pooled[:, None])
        projected = self.proj_out(self.final_layer_norm(mapped))

        if not return_uncond_vector:
            return projected
        return projected, self.uncond_vector


class PaintByExampleMapper(nn.Module):
    """A stack of single-head `BasicTransformerBlock`s applied one after another."""

    def __init__(self, config):
        """Create `(config.num_hidden_layers + 1) // 5` blocks of width `config.hidden_size`."""
        super().__init__()
        depth = (config.num_hidden_layers + 1) // 5
        width = config.hidden_size
        head_count = 1

        layers = []
        for _ in range(depth):
            layers.append(
                BasicTransformerBlock(width, head_count, width, activation_fn="gelu", attention_bias=True)
            )
        self.blocks = nn.ModuleList(layers)

    def forward(self, hidden_states):
        """Feed `hidden_states` through every block in order and return the final result."""
        output = hidden_states
        for layer in self.blocks:
            output = layer(output)
        return output
def prepare_mask_and_masked_image(image, mask):
    """
    Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline.

    Both inputs are brought to 4-dimensional ``torch.Tensor``s laid out as ``batch x channels x height x width``,
    with ``3`` channels for the ``image`` and ``1`` for the ``mask``. The ``image`` is converted to
    ``torch.float32`` and is expected to be (tensor input) or is normalized to be (PIL input) in ``[-1, 1]``.

    The ``mask`` is first inverted (``1 - mask``, because Paint by Example uses the opposite mask convention) and
    then binarized: values below ``0.5`` become ``0`` and all others become ``1``. For PIL input the mask is
    ``torch.float32``; for tensor input the mask keeps the dtype of the tensor that was passed in.

    Args:
        image (Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor]): The image to inpaint.
            Either a ``PIL.Image`` (or a list of them), a ``3 x height x width`` ``torch.Tensor`` or a
            ``batch x 3 x height x width`` ``torch.Tensor``. Non-tensor inputs are processed through
            ``.convert("RGB")``, so they must be PIL images.
        mask (Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor]): The mask to apply to the image,
            i.e. regions to inpaint. Either a ``PIL.Image`` (or a list of them), a ``height x width``
            ``torch.Tensor``, a ``1 x height x width`` / ``batch x height x width`` ``torch.Tensor`` or a
            ``batch x 1 x height x width`` ``torch.Tensor``.

    Raises:
        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
        ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
        ValueError: ``mask`` and ``image`` should have 4 dimensions after batching, the same spatial dimensions
            and the same batch size; the mask should have a single channel; an unbatched image should have
            shape ``(3, H, W)``.
        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).

    Returns:
        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 dimensions: ``batch x
            channels x height x width``. ``masked_image`` is ``image`` multiplied by the inverted, binarized mask.
    """
    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")

        # Batch single image
        if image.ndim == 3:
            if image.shape[0] != 3:
                raise ValueError("Image outside a batch should be of shape (3, H, W)")
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Batched mask
            if mask.shape[0] == image.shape[0]:
                mask = mask.unsqueeze(1)
            else:
                mask = mask.unsqueeze(0)

        # Explicit exceptions instead of `assert`: assertions are removed when Python runs with -O,
        # which would let malformed inputs through to the arithmetic below.
        if image.ndim != 4 or mask.ndim != 4:
            raise ValueError("Image and Mask must have 4 dimensions")
        if image.shape[-2:] != mask.shape[-2:]:
            raise ValueError("Image and Mask must have the same spatial dimensions")
        if image.shape[0] != mask.shape[0]:
            raise ValueError("Image and Mask must have the same batch size")
        if mask.shape[1] != 1:
            raise ValueError("Mask image must have a single channel")

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # paint-by-example inverses the mask; `1 - mask` allocates a new tensor, so the in-place
        # binarization below never writes into the caller's mask.
        mask = 1 - mask

        # Binarize mask
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # Image as float32
        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
    else:
        if isinstance(image, PIL.Image.Image):
            image = [image]

        # Stack the RGB images along a new batch axis, then move channels first.
        image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
        image = image.transpose(0, 3, 1, 2)
        # Map 8-bit pixel values from [0, 255] to [-1, 1].
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, PIL.Image.Image):
            mask = [mask]

        # Single-channel ("L") masks get explicit batch and channel axes.
        mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
        mask = mask.astype(np.float32) / 255.0

        # paint-by-example inverses the mask
        mask = 1 - mask

        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

    masked_image = image * mask

    return mask, masked_image
def __init__(
    self,
    vae: AutoencoderKL,
    image_encoder: PaintByExampleImageEncoder,
    unet: UNet2DConditionModel,
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = False,
):
    """Register the pipeline components and build the VAE-aware image processor.

    All six model/scheduler/processor arguments are handed to `register_modules`; the
    `requires_safety_checker` flag is stored through `register_to_config`.
    """
    super().__init__()

    components = {
        "vae": vae,
        "image_encoder": image_encoder,
        "unet": unet,
        "scheduler": scheduler,
        "safety_checker": safety_checker,
        "feature_extractor": feature_extractor,
    }
    self.register_modules(**components)

    # One factor of two for every VAE block after the first.
    vae_block_count = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (vae_block_count - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    self.register_to_config(requires_safety_checker=requires_safety_checker)
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that `self.scheduler.step` is able to accept.

    Schedulers do not share one `step` signature, so the signature is inspected and only the
    supported names are forwarded. `eta` (η in the DDIM paper, https://arxiv.org/abs/2010.02502,
    expected in [0, 1]) is only used by schedulers that declare it; the same holds for `generator`.

    Returns:
        A dict holding `"eta"` and/or `"generator"`, in that order, for each name found among the
        parameters of `self.scheduler.step`.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_values = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_values if name in step_parameters}
def check_inputs(self, image, height, width, callback_steps):
    """Validate the example image type, the output size and the callback frequency.

    Raises:
        ValueError: If `image` is not a `torch.Tensor`, a `PIL.Image.Image` or a list, if `height`
            or `width` is not a multiple of 8, or if `callback_steps` is not a positive `int`.
    """
    is_supported_image = (
        isinstance(image, torch.Tensor) or isinstance(image, PIL.Image.Image) or isinstance(image, list)
    )
    if not is_supported_image:
        raise ValueError(
            "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
            f" {type(image)}"
        )

    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None`, non-integers and non-positive integers are all rejected.
    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )
def prepare_mask_latents(
    self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
    """Resize the mask to latent resolution and bring mask and masked-image latents to the target batch size.

    Args:
        mask: Mask tensor; it is resized with `torch.nn.functional.interpolate` to
            `(height // self.vae_scale_factor, width // self.vae_scale_factor)`.
        masked_image: Masked image tensor. When its channel dimension (`shape[1]`) is 4 it is used
            directly as latents, otherwise it is encoded through `self._encode_vae_image`.
        batch_size: Total batch size the outputs are repeated up to.
        height: Image height used to compute the latent height.
        width: Image width used to compute the latent width.
        dtype: dtype both outputs are cast to.
        device: Device both outputs are moved to.
        generator: Forwarded to `self._encode_vae_image` when encoding is needed.
        do_classifier_free_guidance: When truthy, both outputs are concatenated with themselves
            along the batch dimension.

    Returns:
        The tuple `(mask, masked_image_latents)`.

    Raises:
        ValueError: If fewer masks (or images) than `batch_size` are given and `batch_size` is not a
            multiple of their number.
    """
    # resize the mask to latents shape as we concatenate the mask to the latents
    # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
    # and half precision
    mask = torch.nn.functional.interpolate(
        mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
    )
    mask = mask.to(device=device, dtype=dtype)

    masked_image = masked_image.to(device=device, dtype=dtype)

    if masked_image.shape[1] == 4:
        masked_image_latents = masked_image
    else:
        masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

    # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
    if mask.shape[0] < batch_size:
        if batch_size % mask.shape[0] != 0:
            # The check is "batch size is a multiple of the mask count"; the message says exactly that.
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the total"
                " requested batch size is divisible by the number of masks that you pass."
            )
        mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
    if masked_image_latents.shape[0] < batch_size:
        if batch_size % masked_image_latents.shape[0] != 0:
            raise ValueError(
                "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                " Make sure the total requested batch size is divisible by the number of images that you pass."
            )
        masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

    if do_classifier_free_guidance:
        mask = torch.cat([mask] * 2)
        masked_image_latents = torch.cat([masked_image_latents] * 2)

    # aligning device to prevent device errors when concating it with the latent model input
    masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
    return mask, masked_image_latents
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings @torch.no_grad() def __call__( self, example_image: Union[torch.Tensor, PIL.Image.Image], image: Union[torch.Tensor, PIL.Image.Image], mask_image: Union[torch.Tensor, PIL.Image.Image], height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, ): r""" The call function to the pipeline for generation. Args: example_image (`torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): An example image to guide image generation. image (`torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): `Image` or tensor representing an image batch to be inpainted (parts of the image are masked out with `mask_image` and repainted according to `prompt`). mask_image (`torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted, while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. 
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. Example: ```py >>> import PIL >>> import requests >>> import torch >>> from io import BytesIO >>> from diffusers import PaintByExamplePipeline >>> def download_image(url): ... response = requests.get(url) ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") >>> img_url = ( ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/image/example_1.png" ... ) >>> mask_url = ( ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/mask/example_1.png" ... ) >>> example_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/reference/example_1.jpg" >>> init_image = download_image(img_url).resize((512, 512)) >>> mask_image = download_image(mask_url).resize((512, 512)) >>> example_image = download_image(example_url).resize((512, 512)) >>> pipe = PaintByExamplePipeline.from_pretrained( ... "Fantasy-Studio/Paint-by-Example", ... torch_dtype=torch.float16, ... 
) >>> pipe = pipe.to("cuda") >>> image = pipe(image=init_image, mask_image=mask_image, example_image=example_image).images[0] >>> image ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ # 1. Define call parameters if isinstance(image, PIL.Image.Image): batch_size = 1 elif isinstance(image, list): batch_size = len(image) else: batch_size = image.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 2. Preprocess mask and image mask, masked_image = prepare_mask_and_masked_image(image, mask_image) height, width = masked_image.shape[-2:] # 3. Check inputs self.check_inputs(example_image, height, width, callback_steps) # 4. Encode input image image_embeddings = self._encode_image( example_image, device, num_images_per_prompt, do_classifier_free_guidance ) # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, image_embeddings.dtype, device, generator, latents, ) # 7. 
Prepare mask latent variables mask, masked_image_latents = self.prepare_mask_latents( mask, masked_image, batch_size * num_images_per_prompt, height, width, image_embeddings.dtype, device, generator, do_classifier_free_guidance, ) # 8. Check that sizes of mask, masked image and latents match num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 10. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = torch.cat([latent_model_input, masked_image_latents, mask], dim=1) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) self.maybe_free_model_hooks() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, 
# Package initialiser for the PIA pipeline: exposes `PIAPipeline` and
# `PIAPipelineOutput` when both torch and transformers are available, and
# dummy placeholder objects otherwise.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Placeholder objects that are attached to this module (see the final loop)
# when the optional dependencies are missing.
_dummy_objects = {}
# Maps submodule name -> public names; handed to `_LazyModule` below.
_import_structure = {}

try:
    # Both backends are required; raising here routes us to the dummy branch.
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_pia"] = ["PIAPipeline", "PIAPipelineOutput"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: perform the real imports now (type checkers, or when slow
    # import is explicitly requested via `DIFFUSERS_SLOW_IMPORT`).
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_pia import PIAPipeline, PIAPipelineOutput

else:
    import sys

    # Replace this module in `sys.modules` with a `_LazyModule` built from
    # `_import_structure` (presumably deferring the submodule import until a
    # listed name is first accessed -- see `_LazyModule` for the exact contract).
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    # Attach the dummy placeholders (empty dict when dependencies are present).
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
# See the License for the specific language governing permissions and # limitations under the License. import inspect from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unets.unet_motion_model import MotionAdapter from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, ) from ...utils import ( USE_PEFT_BACKEND, BaseOutput, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import EulerDiscreteScheduler, MotionAdapter, PIAPipeline >>> from diffusers.utils import export_to_gif, load_image >>> adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter") >>> pipe = PIAPipeline.from_pretrained( ... "SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16 ... ) >>> pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) >>> image = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true" ... 
# Per-frame mask coefficient profiles, indexed by `motion_scale`.
# Entry k of a profile is the coefficient for a frame that is k frames away
# from the conditioning frame; profiles shorter than the clip are padded with
# their last value by `prepare_mask_coef_by_statistics`.
RANGE_LIST = [
    [1.0, 0.9, 0.85, 0.85, 0.85, 0.8],  # 0: Small Motion
    [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75],  # 1: Moderate Motion
    [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5],  # 2: Large Motion
    [1.0, 0.9, 0.85, 0.85, 0.85, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.85, 0.85, 0.9, 1.0],  # 3: Loop
    [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8, 0.8, 1.0],  # 4: Loop
    [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5, 0.6, 0.7, 0.7, 0.7, 0.7, 0.8, 1.0],  # 5: Loop
    [0.5, 0.4, 0.4, 0.4, 0.35, 0.3],  # 6: Style Transfer Candidate Small Motion
    [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2],  # 7: Style Transfer Moderate Motion
    [0.5, 0.2],  # 8: Style Transfer Large Motion
]


def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale: int) -> List[float]:
    """
    Build the per-frame mask coefficients for a clip.

    The profile `RANGE_LIST[motion_scale]` is padded with its last value up to `num_frames` entries, and frame `i`
    then receives the profile entry at index `abs(i - cond_frame)`, i.e. its distance from the conditioning frame.

    Args:
        num_frames (`int`):
            Number of frames in the clip. Must be greater than 0.
        cond_frame (`int`):
            Index of the conditioning frame. Must satisfy `0 <= cond_frame < num_frames`.
        motion_scale (`int`):
            Index into `RANGE_LIST`. Must satisfy `0 <= motion_scale < len(RANGE_LIST)`.

    Returns:
        `List[float]`: A list of `num_frames` coefficients.

    Raises:
        `AssertionError`: If any argument is outside the ranges above.
    """
    assert num_frames > 0, "video_length should be greater than 0"
    # A negative index would measure distances from a frame outside the clip and
    # either read past the profile or return coefficients for the wrong frames.
    assert cond_frame >= 0, "cond_frame should be non-negative"
    assert num_frames > cond_frame, "video_length should be greater than cond_frame"

    range_list = RANGE_LIST
    # The lower bound matters: Python would otherwise accept a negative index
    # and silently select a profile counted from the end of the table.
    assert 0 <= motion_scale < len(range_list), f"motion_scale type{motion_scale} not implemented"

    coef = range_list[motion_scale]
    # Pad with the last value; a non-positive repeat count yields an empty list,
    # so profiles longer than the clip are left unchanged.
    coef = coef + ([coef[-1]] * (num_frames - len(coef)))

    return [coef[abs(i - cond_frame)] for i in range(num_frames)]
Args: frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`, NumPy array of shape `(batch_size, num_frames, channels, height, width, Torch tensor of shape `(batch_size, num_frames, channels, height, width)`. """ frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]] class PIAPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, FromSingleFileMixin, FreeInitMixin, ): r""" Pipeline for text-to-video generation. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer (`CLIPTokenizer`): A [`~transformers.CLIPTokenizer`] to tokenize text. unet ([`UNet2DConditionModel`]): A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. motion_adapter ([`MotionAdapter`]): A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt (and, for classifier-free guidance, the negative prompt) into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt to be encoded. Only tokenized and encoded when `prompt_embeds` is not given.
        device: (`torch.device`):
            Torch device the token ids and the resulting embeddings are moved to.
        num_images_per_prompt (`int`):
            Number of generations per prompt; each prompt embedding is repeated this many times. `__call__` of
            this pipeline passes its `num_videos_per_prompt` here.
        do_classifier_free_guidance (`bool`):
            Whether to use classifier free guidance or not. When `False`, `negative_prompt` is not encoded and
            `negative_prompt_embeds` is returned as passed in.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the generation. Only used when `do_classifier_free_guidance` is
            `True` and `negative_prompt_embeds` is not given; an empty string per prompt is used when it is `None`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`. The two are returned separately; concatenating them
        for classifier-free guidance is left to the caller.

    Raises:
        `TypeError`: If `negative_prompt` and `prompt` are of different types.
        `ValueError`: If `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # Batch size comes from the prompt when given, otherwise from the
    # pre-computed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize a second time without truncation, only to detect (and
        # report) what the truncated encoding above dropped.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # Only pass an attention mask when the text encoder config asks for one.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Target dtype: prefer the text encoder's, then the UNet's, and only fall
    # back to whatever dtype the embeddings already have.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad/truncate the negative prompt to the positive embeddings' sequence
        # length so that both have the same second dimension.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        # NOTE(review): this reshape uses `batch_size` (derived from `prompt` /
        # `prompt_embeds`), so user-supplied `negative_prompt_embeds` are assumed
        # to have the same leading dimension -- confirm against callers.
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Not all schedulers share the same `step` signature, so the signature is inspected and only the supported
    arguments are forwarded. eta (η) corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and
    should be between [0, 1]; schedulers whose `step` has no `eta` parameter simply do not receive it.

    Args:
        generator: Random generator(s) to forward when `step` has a `generator` parameter.
        eta: Value to forward when `step` has an `eta` parameter.

    Returns:
        `dict`: Contains `"eta"` and/or `"generator"`, depending on what `step` accepts.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_arguments = {"eta": eta, "generator": generator}
    return {key: value for key, value in optional_arguments.items() if key in step_parameters}
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build the list of IP-Adapter image embeddings, one tensor per adapter.

    When `ip_adapter_image_embeds` is given it is used as-is (each tensor is split in two halves, negative first,
    when `do_classifier_free_guidance` is set). Otherwise every entry of `ip_adapter_image` is encoded with
    `self.encode_image`, paired with the matching layer of `self.unet.encoder_hid_proj.image_projection_layers`.

    Each embedding is then repeated `num_images_per_prompt` times along dim 0, prefixed with its (equally
    repeated) negative embedding when guidance is enabled, and moved to `device`.

    Raises:
        `ValueError`: If the number of images differs from the number of image projection layers.
    """
    positive_embeds = []
    negative_embeds = []

    if ip_adapter_image_embeds is not None:
        # Pre-computed embeddings: under guidance each tensor carries the
        # negative half first, followed by the positive half.
        for embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_half, embeds = embeds.chunk(2)
                negative_embeds.append(negative_half)
            positive_embeds.append(embeds)
    else:
        images = ip_adapter_image if isinstance(ip_adapter_image, list) else [ip_adapter_image]
        projection_layers = self.unet.encoder_hid_proj.image_projection_layers

        if len(images) != len(projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(images)} images and {len(projection_layers)} IP Adapters."
            )

        for adapter_image, projection_layer in zip(images, projection_layers):
            # Anything other than a plain `ImageProjection` gets hidden states.
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            encoded, encoded_negative = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positive_embeds.append(encoded[None, :])
            if do_classifier_free_guidance:
                negative_embeds.append(encoded_negative[None, :])

    prepared = []
    for index, embeds in enumerate(positive_embeds):
        embeds = torch.cat([embeds] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            repeated_negative = torch.cat([negative_embeds[index]] * num_images_per_prompt, dim=0)
            embeds = torch.cat([repeated_negative, embeds], dim=0)
        prepared.append(embeds.to(device=device))

    return prepared
def prepare_masked_condition(
    self,
    image,
    batch_size,
    num_channels_latents,
    num_frames,
    height,
    width,
    dtype,
    device,
    generator,
    motion_scale=0,
):
    """
    Build the per-frame mask and the repeated image latents that condition the denoising.

    `image` is preprocessed with `self.video_processor`, encoded with `self.vae`, resized to the latent resolution
    (`height // self.vae_scale_factor`, `width // self.vae_scale_factor`) and multiplied by
    `self.vae.config.scaling_factor`. The resulting latent is copied into every frame of `masked_image`, and every
    frame of `mask` is filled with its coefficient from `prepare_mask_coef_by_statistics(num_frames, 0,
    motion_scale)` (the conditioning frame is always frame 0).

    Args:
        image: Conditioning image(s), in any format `self.video_processor.preprocess` accepts.
        batch_size (`int`): Size of the leading dimension of both outputs (before guidance doubling).
        num_channels_latents (`int`): Number of channels of `masked_image`.
        num_frames (`int`): Number of frames.
        height (`int`), width (`int`): Pixel size of the generated video.
        dtype (`torch.dtype`): dtype of the encoded image and of `mask`.
        device (`torch.device`): Device of both outputs.
        generator (`torch.Generator` or `List[torch.Generator]`):
            Generator(s) for sampling the VAE latent distribution. With a list, `image[k : k + 1]` is encoded
            separately with `generator[k]` for each `k` in `range(batch_size)`.
        motion_scale (`int`, defaults to 0): Index of the coefficient profile.

    Returns:
        `tuple`: `(mask, masked_image)` of shapes `(batch_size, 1, num_frames, h, w)` and `(batch_size,
        num_channels_latents, num_frames, h, w)`; both are duplicated along dim 0 when
        `self.do_classifier_free_guidance` is true.
    """
    scaled_height = height // self.vae_scale_factor
    scaled_width = width // self.vae_scale_factor

    image = self.video_processor.preprocess(image)
    image = image.to(device, dtype)

    if isinstance(generator, list):
        image_latent = [
            self.vae.encode(image[k : k + 1]).latent_dist.sample(generator[k]) for k in range(batch_size)
        ]
        image_latent = torch.cat(image_latent, dim=0)
    else:
        image_latent = self.vae.encode(image).latent_dist.sample(generator)

    image_latent = image_latent.to(device=device, dtype=dtype)
    image_latent = torch.nn.functional.interpolate(image_latent, size=[scaled_height, scaled_width])
    # The multiplication already yields a new tensor, so no explicit clone is needed.
    image_latent_padding = image_latent * self.vae.config.scaling_factor

    # Allocate directly on the target device instead of on the CPU followed by a transfer.
    mask = torch.zeros((batch_size, 1, num_frames, scaled_height, scaled_width), device=device, dtype=dtype)
    # The channel count follows `num_channels_latents` rather than a hard-coded 4.
    # The dtype intentionally stays `self.unet.dtype`, as before.
    masked_image = torch.zeros(
        (batch_size, num_channels_latents, num_frames, scaled_height, scaled_width),
        device=device,
        dtype=self.unet.dtype,
    )

    mask_coef = prepare_mask_coef_by_statistics(num_frames, 0, motion_scale)
    for frame_idx, coef in enumerate(mask_coef):
        mask[:, :, frame_idx, :, :] = coef
        # Slice assignment copies (and casts) the source, so a per-frame clone is redundant.
        masked_image[:, :, frame_idx, :, :] = image_latent_padding

    if self.do_classifier_free_guidance:
        # Duplicate along the batch dimension for the unconditional + conditional passes.
        mask = torch.cat([mask] * 2)
        masked_image = torch.cat([masked_image] * 2)

    return mask, masked_image
@property
def guidance_scale(self):
    """Value stored in `self._guidance_scale`; `__call__` sets it from its `guidance_scale` argument."""
    return self._guidance_scale

@property
def clip_skip(self):
    """Value stored in `self._clip_skip`; `__call__` sets it from its `clip_skip` argument."""
    return self._clip_skip

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """`True` when the stored guidance scale is strictly greater than 1."""
    return self._guidance_scale > 1

@property
def cross_attention_kwargs(self):
    """Value stored in `self._cross_attention_kwargs`; `__call__` sets it from its argument of the same name."""
    return self._cross_attention_kwargs

@property
def num_timesteps(self):
    """Value stored in `self._num_timesteps`; `__call__` sets it to `len(timesteps)` of the current run."""
    return self._num_timesteps
callback_on_step_end_tensor_inputs: List[str] = ["latents"], ): r""" The call function to the pipeline for generation. Args: image (`PipelineImageInput`): The input image to be used for video generation. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. strength (`float`, *optional*, defaults to 1.0): Indicates extent to transform the reference `image`. Must be between 0 and 1. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated video. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated video. num_frames (`int`, *optional*, defaults to 16): The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds amounts to 2 seconds of video. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality videos at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. Latents should be of shape `(batch_size, num_channel, num_frames, height, width)`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. motion_scale: (`int`, *optional*, defaults to 0): Parameter that controls the amount and type of motion that is added to the image. Increasing the value increases the amount of motion, while specific ranges of values control the type of motion that is added. Must be between 0 and 8. Set between 0-2 to only increase the amount of motion. Set between 3-5 to create looping motion. 
Set between 6-8 to perform motion with image style transfer. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. """ # 0. 
Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, negative_prompt, prompt_embeds, negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_videos_per_prompt, self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt, self.do_classifier_free_guidance, ) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) self._num_timesteps = len(timesteps) # 5. Prepare latent variables latents = self.prepare_latents( batch_size * num_videos_per_prompt, 4, num_frames, height, width, prompt_embeds.dtype, device, generator, latents=latents, ) mask, masked_image = self.prepare_masked_condition( image, batch_size * num_videos_per_prompt, 4, num_frames=num_frames, height=height, width=width, dtype=self.unet.dtype, device=device, generator=generator, motion_scale=motion_scale, ) if strength < 1.0: noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) latents = self.scheduler.add_noise(masked_image[0], noise, latent_timestep) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None ) # 8. 
Denoising loop num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 for free_init_iter in range(num_free_init_iters): if self.free_init_enabled: latents, timesteps = self._apply_free_init( latents, free_init_iter, num_inference_steps, device, latents.dtype, generator ) self._num_timesteps = len(timesteps) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=self._num_timesteps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = torch.cat([latent_model_input, mask, masked_image], dim=1) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, ).sample # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() # 9. 
Post processing if output_type == "latent": video = latents else: video_tensor = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() if not return_dict: return (video,) return PIAPipelineOutput(frames=video) ================================================ FILE: src/diffusers/pipelines/pipeline_flax_utils.py ================================================ # coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
def import_flax_or_no_model(module, class_name):
    """
    Resolve `class_name` on `module`, preferring its Flax-prefixed variant.

    Args:
        module: Object (typically an imported module) on which the class is looked up with `getattr`.
        class_name (`str`): Class name without the `"Flax"` prefix.

    Returns:
        The attribute named `"Flax" + class_name` if `module` has it, otherwise the attribute named `class_name`.

    Raises:
        ValueError: If `module` exposes neither name.
    """
    # 1. First make sure that if a Flax object is present, import this one
    # 2. If this doesn't work, it's not a model and we don't append "Flax"
    for candidate_name in ("Flax" + class_name, class_name):
        try:
            return getattr(module, candidate_name)
        except AttributeError:
            continue

    # A second `except AttributeError` clause on the same `try` can never catch an error raised inside the first
    # handler, so the failure is reported explicitly once both lookups have been attempted.
    raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}")
if library not in LOADABLE_CLASSES or is_pipeline_module: library = pipeline_dir # retrieve class_name class_name = module.__class__.__name__ register_dict = {name: (library, class_name)} # save model index config self.register_to_config(**register_dict) # set models setattr(self, name, module) def save_pretrained( self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict], push_to_hub: bool = False, **kwargs, ): # TODO: handle inference_state """ Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading method. The pipeline is easily reloaded using the [`~FlaxDiffusionPipeline.from_pretrained`] class method. Arguments: save_directory (`str` or `os.PathLike`): Directory to which to save. Will be created if it doesn't exist. push_to_hub (`bool`, *optional*, defaults to `False`): Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. 
""" self.save_config(save_directory) model_index_dict = dict(self.config) model_index_dict.pop("_class_name") model_index_dict.pop("_diffusers_version") model_index_dict.pop("_module", None) if push_to_hub: commit_message = kwargs.pop("commit_message", None) private = kwargs.pop("private", False) create_pr = kwargs.pop("create_pr", False) token = kwargs.pop("token", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) if sub_model is None: # edge case for saving a pipeline with safety_checker=None continue model_cls = sub_model.__class__ save_method_name = None # search for the model's base class in LOADABLE_CLASSES for library_name, library_classes in LOADABLE_CLASSES.items(): library = importlib.import_module(library_name) for base_class, save_load_methods in library_classes.items(): class_candidate = getattr(library, base_class, None) if class_candidate is not None and issubclass(model_cls, class_candidate): # if we found a suitable base class in LOADABLE_CLASSES then grab its save method save_method_name = save_load_methods[0] break if save_method_name is not None: break save_method = getattr(sub_model, save_method_name) expects_params = "params" in set(inspect.signature(save_method).parameters.keys()) if expects_params: save_method( os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name] ) else: save_method(os.path.join(save_directory, pipeline_component_name)) if push_to_hub: self._upload_folder( save_directory, repo_id, token=token, commit_message=commit_message, create_pr=create_pr, ) @classmethod @validate_hf_hub_args def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" Instantiate a Flax-based diffusion pipeline from pretrained pipeline weights. 
The pipeline is set in evaluation mode (`model.eval()) by default and dropout modules are deactivated. If you get the error message below, you need to finetune the weights for your downstream task: ``` Some weights of FlaxUNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: ``` Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - A string, the *repo id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained pipeline hosted on the Hub. - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved using [`~FlaxDiffusionPipeline.save_pretrained`]. dtype (`str` or `jnp.dtype`, *optional*): Override the default `jnp.dtype` and load the model under this dtype. If `"auto"`, the dtype is automatically derived from the model's weights. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. 
It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you're downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load and saveable variables (the pipeline components) of the specific pipeline class. The overwritten components are passed directly to the pipelines `__init__` method. To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `huggingface-cli login`. Examples: ```py >>> from diffusers import FlaxDiffusionPipeline >>> # Download pipeline from huggingface.co and cache. >>> # Requires to be logged in to Hugging Face hub, >>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens) >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", ... variant="bf16", ... dtype=jnp.bfloat16, ... ) >>> # Download pipeline, but use a different scheduler >>> from diffusers import FlaxDPMSolverMultistepScheduler >>> model_id = "runwayml/stable-diffusion-v1-5" >>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained( ... model_id, ... subfolder="scheduler", ... ) >>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained( ... model_id, variant="bf16", dtype=jnp.bfloat16, scheduler=dpmpp ... 
) >>> dpm_params["scheduler"] = dpmpp_state ``` """ cache_dir = kwargs.pop("cache_dir", None) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) token = kwargs.pop("token", None) revision = kwargs.pop("revision", None) from_pt = kwargs.pop("from_pt", False) use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False) split_head_dim = kwargs.pop("split_head_dim", False) dtype = kwargs.pop("dtype", None) # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): config_dict = cls.load_config( pretrained_model_name_or_path, cache_dir=cache_dir, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision, ) # make sure we only download sub-folders and `diffusers` filenames folder_names = [k for k in config_dict.keys() if not k.startswith("_")] allow_patterns = [os.path.join(k, "*") for k in folder_names] allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] ignore_patterns = ["*.bin", "*.safetensors"] if not from_pt else [] ignore_patterns += ["*.onnx", "*.onnx_data", "*.xml", "*.pb"] if cls != FlaxDiffusionPipeline: requested_pipeline_class = cls.__name__ else: requested_pipeline_class = config_dict.get("_class_name", cls.__name__) requested_pipeline_class = ( requested_pipeline_class if requested_pipeline_class.startswith("Flax") else "Flax" + requested_pipeline_class ) user_agent = {"pipeline_class": requested_pipeline_class} user_agent = http_user_agent(user_agent) # download all allow_patterns cached_folder = snapshot_download( pretrained_model_name_or_path, cache_dir=cache_dir, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, user_agent=user_agent, ) else: cached_folder = pretrained_model_name_or_path config_dict = 
cls.load_config(cached_folder) # 2. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it if cls != FlaxDiffusionPipeline: pipeline_class = cls else: diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) class_name = ( config_dict["_class_name"] if config_dict["_class_name"].startswith("Flax") else "Flax" + config_dict["_class_name"] ) pipeline_class = getattr(diffusers_module, class_name) # some modules can be passed directly to the init # in this case they are already instantiated in `kwargs` # extract them here expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) # define init kwargs init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict} init_kwargs = {**init_kwargs, **passed_pipe_kwargs} # remove `null` components def load_module(name, value): if value[0] is None: return False if name in passed_class_obj and passed_class_obj[name] is None: return False return True init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} # Throw nice warnings / errors for fast accelerate loading if len(unused_kwargs) > 0: logger.warning( f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." ) # inference_params params = {} # import it here to avoid circular import from diffusers import pipelines # 3. 
Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): if class_name is None: # edge case for when the pipeline was saved with safety_checker=None init_kwargs[name] = None continue is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None sub_model_should_be_defined = True # if the model is in a pipeline module, then we load it from the pipeline if name in passed_class_obj: # 1. check that passed_class_obj has correct parent class if not is_pipeline_module: library = importlib.import_module(library_name) class_obj = getattr(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} expected_class_obj = None for class_name, class_candidate in class_candidates.items(): if class_candidate is not None and issubclass(class_obj, class_candidate): expected_class_obj = class_candidate if not issubclass(passed_class_obj[name].__class__, expected_class_obj): raise ValueError( f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f" {expected_class_obj}" ) elif passed_class_obj[name] is None: logger.warning( f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" f" that this might lead to problems when using {pipeline_class} and is not recommended." ) sub_model_should_be_defined = False else: logger.warning( f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" " has the correct type" ) # set passed class object loaded_sub_model = passed_class_obj[name] elif is_pipeline_module: pipeline_module = getattr(pipelines, library_name) class_obj = import_flax_or_no_model(pipeline_module, class_name) importable_classes = ALL_IMPORTABLE_CLASSES class_candidates = {c: class_obj for c in importable_classes.keys()} else: # else we just import it from the library. 
library = importlib.import_module(library_name) class_obj = import_flax_or_no_model(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} if loaded_sub_model is None and sub_model_should_be_defined: load_method_name = None for class_name, class_candidate in class_candidates.items(): if class_candidate is not None and issubclass(class_obj, class_candidate): load_method_name = importable_classes[class_name][1] load_method = getattr(class_obj, load_method_name) # check if the module is in a subdirectory if os.path.isdir(os.path.join(cached_folder, name)): loadable_folder = os.path.join(cached_folder, name) else: loaded_sub_model = cached_folder if issubclass(class_obj, FlaxModelMixin): loaded_sub_model, loaded_params = load_method( loadable_folder, from_pt=from_pt, use_memory_efficient_attention=use_memory_efficient_attention, split_head_dim=split_head_dim, dtype=dtype, ) params[name] = loaded_params elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): if from_pt: # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here loaded_sub_model = load_method(loadable_folder, from_pt=from_pt) loaded_params = loaded_sub_model.params del loaded_sub_model._params else: loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) params[name] = loaded_params elif issubclass(class_obj, FlaxSchedulerMixin): loaded_sub_model, scheduler_state = load_method(loadable_folder) params[name] = scheduler_state else: loaded_sub_model = load_method(loadable_folder) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) # 4. 
Potentially add passed objects if expected missing_modules = set(expected_modules) - set(init_kwargs.keys()) passed_modules = list(passed_class_obj.keys()) if len(missing_modules) > 0 and missing_modules <= set(passed_modules): for module in missing_modules: init_kwargs[module] = passed_class_obj.get(module, None) elif len(missing_modules) > 0: passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs raise ValueError( f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." ) model = pipeline_class(**init_kwargs, dtype=dtype) return model, params @classmethod def _get_signature_keys(cls, obj): parameters = inspect.signature(obj.__init__).parameters required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) expected_modules = set(required_parameters.keys()) - {"self"} return expected_modules, optional_parameters @property def components(self) -> Dict[str, Any]: r""" The `self.components` property can be useful to run different pipelines with the same weights and configurations to not have to re-allocate memory. Examples: ```py >>> from diffusers import ( ... FlaxStableDiffusionPipeline, ... FlaxStableDiffusionImg2ImgPipeline, ... ) >>> text2img = FlaxStableDiffusionPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jnp.bfloat16 ... ) >>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components) ``` Returns: A dictionary containing all the modules needed to initialize the pipeline. 
""" expected_modules, optional_parameters = self._get_signature_keys(self) components = { k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters } if set(components.keys()) != expected_modules: raise ValueError( f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" f" {expected_modules} to be defined, but {components} are defined." ) return components @staticmethod def numpy_to_pil(images): """ Convert a NumPy image or a batch of images to a PIL image. """ if images.ndim == 3: images = images[None, ...] images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [Image.fromarray(image) for image in images] return pil_images # TODO: make it compatible with jax.lax def progress_bar(self, iterable): if not hasattr(self, "_progress_bar_config"): self._progress_bar_config = {} elif not isinstance(self._progress_bar_config, dict): raise ValueError( f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." ) return tqdm(iterable, **self._progress_bar_config) def set_progress_bar_config(self, **kwargs): self._progress_bar_config = kwargs ================================================ FILE: src/diffusers/pipelines/pipeline_loading_utils.py ================================================ # coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import importlib import os import re import warnings from pathlib import Path from typing import Any, Dict, List, Optional, Union import torch from huggingface_hub import model_info from huggingface_hub.utils import validate_hf_hub_args from packaging import version from .. import __version__ from ..utils import ( FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, get_class_from_dynamic_module, is_accelerate_available, is_peft_available, is_transformers_available, logging, ) from ..utils.torch_utils import is_compiled_module if is_transformers_available(): import transformers from transformers import PreTrainedModel from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME if is_accelerate_available(): import accelerate from accelerate import dispatch_model from accelerate.hooks import remove_hook_from_module from accelerate.utils import compute_module_sizes, get_max_memory INDEX_FILE = "diffusion_pytorch_model.bin" CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" DUMMY_MODULES_FOLDER = "diffusers.utils" TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils" CONNECTED_PIPES_KEYS = ["prior"] logger = logging.get_logger(__name__) LOADABLE_CLASSES = { "diffusers": { "ModelMixin": ["save_pretrained", "from_pretrained"], "SchedulerMixin": ["save_pretrained", "from_pretrained"], "DiffusionPipeline": ["save_pretrained", "from_pretrained"], "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], }, "transformers": { "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], "PreTrainedModel": ["save_pretrained", "from_pretrained"], "FeatureExtractionMixin": 
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
    """
    Checking for safetensors compatibility:
    - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
      files to know which safetensors files are needed.
    - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.

    Converting default pytorch serialized filenames to safetensors serialized filenames:
    - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
    - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
      extension is replaced with ".safetensors"

    Files of the form `<component>/<file>` whose component is listed in `passed_components` are ignored. `variant`
    is accepted but not consulted by this function.
    """
    skipped_components = passed_components or []
    bin_paths = []
    safetensors_paths = set()

    for candidate in filenames:
        parts = candidate.split("/")
        if len(parts) == 2 and parts[0] in skipped_components:
            continue

        suffix = os.path.splitext(candidate)[1]
        if suffix == ".bin":
            bin_paths.append(os.path.normpath(candidate))
        elif suffix == ".safetensors":
            safetensors_paths.add(os.path.normpath(candidate))

    for bin_path in bin_paths:
        # bin_path = 'foo/bar/baz.bam' -> folder = 'foo/bar', stem = 'baz'
        folder, basename = os.path.split(bin_path)
        stem = os.path.splitext(basename)[0]
        if stem.startswith("pytorch_model"):
            stem = stem.replace("pytorch_model", "model")

        expected_sf_filename = os.path.normpath(os.path.join(folder, stem)) + ".safetensors"
        if expected_sf_filename not in safetensors_paths:
            logger.warning(f"{expected_sf_filename} not found")
            return False

    return True
@validate_hf_hub_args
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
    """
    Emit a `FutureWarning` when a model variant is being loaded through `revision` instead of `variant`.

    The file listing of the repository is fetched with `revision=None` and the files that would be selected
    when treating `revision` as a variant name are computed. If they cover `model_filenames`, the user is told
    to switch to `variant=`; otherwise the warning additionally asks the user to open an issue because the
    variant files are missing.

    Args:
        pretrained_model_name_or_path: Hub repository id passed to `model_info`.
        token: Authentication token forwarded to `model_info`.
        variant: Not used by this function.
        revision: The revision the user requested; interpreted here as the candidate variant name.
        model_filenames: File names that are about to be loaded.
    """
    info = model_info(
        pretrained_model_name_or_path,
        token=token,
        revision=None,
    )
    filenames = {sibling.rfilename for sibling in info.siblings}
    comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
    # Drop the second dot-separated segment so that the names can be compared with `model_filenames`.
    comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]

    if set(model_filenames).issubset(set(comp_model_filenames)):
        warnings.warn(
            f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant='{revision}'`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
            FutureWarning,
        )
    else:
        warnings.warn(
            f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
            FutureWarning,
        )
def get_class_obj_and_candidates(
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
    """
    Retrieve the class object of a component together with its potential parent classes.

    Resolution order:

    1. `is_pipeline_module` is true: the class is read from `getattr(pipelines, library_name)`.
    2. A file `<cache_dir>/<component_name>/<library_name>.py` exists: the class is loaded as a custom
       component through `get_class_from_dynamic_module`.
    3. Otherwise `library_name` is imported and the class is read from it.

    Args:
        library_name: Name of the pipeline module, custom module file (without `.py`) or importable library.
        class_name: Name of the class to retrieve.
        importable_classes: Mapping whose keys are the candidate base class names.
        pipelines: Object exposing pipeline modules as attributes.
        is_pipeline_module: Whether `library_name` names an attribute of `pipelines`.
        component_name: Optional sub-folder of `cache_dir` holding the component.
        cache_dir: Optional root folder of the downloaded pipeline.

    Returns:
        Tuple `(class_obj, class_candidates)`; `class_candidates` maps every key of `importable_classes` to a
        class (or `None` when the imported library does not define it).
    """
    # Both arguments default to `None`; joining `None` would raise a `TypeError`, so the custom-component
    # lookup is only attempted when a folder can actually be built.
    if cache_dir is not None and component_name is not None:
        component_folder = os.path.join(cache_dir, component_name)
    else:
        component_folder = None

    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)

        class_obj = getattr(pipeline_module, class_name)
        class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
    elif component_folder is not None and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
        # load custom component
        class_obj = get_class_from_dynamic_module(
            component_folder, module_file=library_name + ".py", class_name=class_name
        )
        class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
    else:
        # else we just import it from the library.
        library = importlib.import_module(library_name)

        class_obj = getattr(library, class_name)
        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

    return class_obj, class_candidates
cache_dir=None, revision=None, ): if custom_pipeline is not None: return _get_custom_pipeline_class( custom_pipeline, repo_id=repo_id, hub_revision=hub_revision, class_name=class_name, cache_dir=cache_dir, revision=revision, ) if class_obj.__name__ != "DiffusionPipeline": return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) class_name = class_name or config["_class_name"] if not class_name: raise ValueError( "The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`." ) class_name = class_name[4:] if class_name.startswith("Flax") else class_name pipeline_cls = getattr(diffusers_module, class_name) if load_connected_pipeline: from .auto_pipeline import _get_connected_pipeline connected_pipeline_cls = _get_connected_pipeline(pipeline_cls) if connected_pipeline_cls is not None: logger.info( f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`" ) else: logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.") pipeline_cls = connected_pipeline_cls or pipeline_cls return pipeline_cls def _load_empty_model( library_name: str, class_name: str, importable_classes: List[Any], pipelines: Any, is_pipeline_module: bool, name: str, torch_dtype: Union[str, torch.dtype], cached_folder: Union[str, os.PathLike], **kwargs, ): # retrieve class objects. class_obj, _ = get_class_obj_and_candidates( library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=name, cache_dir=cached_folder, ) if is_transformers_available(): transformers_version = version.parse(version.parse(transformers.__version__).base_version) else: transformers_version = "N/A" # Determine library. 
is_transformers_model = ( is_transformers_available() and issubclass(class_obj, PreTrainedModel) and transformers_version >= version.parse("4.20.0") ) diffusers_module = importlib.import_module(__name__.split(".")[0]) is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin) model = None config_path = cached_folder user_agent = { "diffusers": __version__, "file_type": "model", "framework": "pytorch", } if is_diffusers_model: # Load config and then the model on meta. config, unused_kwargs, commit_hash = class_obj.load_config( os.path.join(config_path, name), cache_dir=cached_folder, return_unused_kwargs=True, return_commit_hash=True, force_download=kwargs.pop("force_download", False), proxies=kwargs.pop("proxies", None), local_files_only=kwargs.pop("local_files_only", False), token=kwargs.pop("token", None), revision=kwargs.pop("revision", None), subfolder=kwargs.pop("subfolder", None), user_agent=user_agent, ) with accelerate.init_empty_weights(): model = class_obj.from_config(config, **unused_kwargs) elif is_transformers_model: config_class = getattr(class_obj, "config_class", None) if config_class is None: raise ValueError("`config_class` cannot be None. 
Please double-check the model.") config = config_class.from_pretrained( cached_folder, subfolder=name, force_download=kwargs.pop("force_download", False), proxies=kwargs.pop("proxies", None), local_files_only=kwargs.pop("local_files_only", False), token=kwargs.pop("token", None), revision=kwargs.pop("revision", None), user_agent=user_agent, ) with accelerate.init_empty_weights(): model = class_obj(config) if model is not None: model = model.to(dtype=torch_dtype) return model def _assign_components_to_devices( module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced" ): device_ids = list(device_memory.keys()) device_cycle = device_ids + device_ids[::-1] device_memory = device_memory.copy() device_id_component_mapping = {} current_device_index = 0 for component in module_sizes: device_id = device_cycle[current_device_index % len(device_cycle)] component_memory = module_sizes[component] curr_device_memory = device_memory[device_id] # If the GPU doesn't fit the current component offload to the CPU. if component_memory > curr_device_memory: device_id_component_mapping["cpu"] = [component] else: if device_id not in device_id_component_mapping: device_id_component_mapping[device_id] = [component] else: device_id_component_mapping[device_id].append(component) # Update the device memory. device_memory[device_id] -= component_memory current_device_index += 1 return device_id_component_mapping def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs): # To avoid circular import problem. from diffusers import pipelines torch_dtype = kwargs.get("torch_dtype", torch.float32) # Load each module in the pipeline on a meta device so that we can derive the device map. 
init_empty_modules = {} for name, (library_name, class_name) in init_dict.items(): if class_name.startswith("Flax"): raise ValueError("Flax pipelines are not supported with `device_map`.") # Define all importable classes is_pipeline_module = hasattr(pipelines, library_name) importable_classes = ALL_IMPORTABLE_CLASSES loaded_sub_model = None # Use passed sub model or load class_name from library_name if name in passed_class_obj: # if the model is in a pipeline module, then we load it from the pipeline # check that passed_class_obj has correct parent class maybe_raise_or_warn( library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module, ) with accelerate.init_empty_weights(): loaded_sub_model = passed_class_obj[name] else: loaded_sub_model = _load_empty_model( library_name=library_name, class_name=class_name, importable_classes=importable_classes, pipelines=pipelines, is_pipeline_module=is_pipeline_module, pipeline_class=pipeline_class, name=name, torch_dtype=torch_dtype, cached_folder=kwargs.get("cached_folder", None), force_download=kwargs.get("force_download", None), proxies=kwargs.get("proxies", None), local_files_only=kwargs.get("local_files_only", None), token=kwargs.get("token", None), revision=kwargs.get("revision", None), ) if loaded_sub_model is not None: init_empty_modules[name] = loaded_sub_model # determine device map # Obtain a sorted dictionary for mapping the model-level components # to their sizes. module_sizes = { module_name: compute_module_sizes(module, dtype=torch_dtype)[""] for module_name, module in init_empty_modules.items() if isinstance(module, torch.nn.Module) } module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True)) # Obtain maximum memory available per device (GPUs only). 
def load_sub_model(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    pipeline_class: Any,
    torch_dtype: torch.dtype,
    provider: Any,
    sess_options: Any,
    device_map: Optional[Union[Dict[str, torch.device], str]],
    max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
    offload_folder: Optional[Union[str, os.PathLike]],
    offload_state_dict: bool,
    model_variants: Dict[str, str],
    name: str,
    from_flax: bool,
    variant: str,
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
):
    """
    Helper method to load the module `name` from `library_name` and `class_name`.

    The class is resolved with `get_class_obj_and_candidates`, its loading method name is looked up in
    `importable_classes`, the keyword arguments appropriate for the kind of class are assembled, and the
    loading method is called on `<cached_folder>/<name>` if that directory exists, otherwise on `cached_folder`.

    Args:
        library_name, class_name, importable_classes, pipelines, is_pipeline_module:
            Forwarded to `get_class_obj_and_candidates`. `importable_classes[<base>][1]` is read as the name
            of the loading method.
        pipeline_class: Only used in the error message raised when no loading method is found.
        torch_dtype: Passed as `torch_dtype` when the class is a `torch.nn.Module` subclass.
        provider, sess_options: Passed when the class is an `OnnxRuntimeModel` subclass.
        device_map, max_memory, offload_folder, offload_state_dict:
            Passed to diffusers / transformers models. When `device_map` is a `dict` and the loaded object is a
            `torch.nn.Module`, the model is additionally re-dispatched with `dispatch_model`.
        model_variants: Per-component variant names. The entry for `name` is popped, i.e. the dictionary
            passed by the caller is modified.
        name: Component name; also the sub-folder looked up inside `cached_folder`.
        from_flax: Adds `from_flax=True` and disables `low_cpu_mem_usage` for transformers models.
        variant: Only used in the `ImportError` message.
        low_cpu_mem_usage: Forwarded to diffusers / transformers models (see `from_flax`).
        cached_folder: Root folder of the pipeline files.

    Returns:
        The object returned by the resolved loading method.

    Raises:
        ValueError: If none of the candidate base classes matches the resolved class.
        ImportError: If a variant is requested for a transformers model with `transformers` < 4.27.0.
    """
    # retrieve class candidates
    class_obj, class_candidates = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    load_method_name = None
    # retrieve load method name
    # NOTE: the loop variable reuses (and therefore overwrites) the `class_name` parameter. The loop does not
    # break, so the last matching candidate determines the load method.
    for class_name, class_candidate in class_candidates.items():
        if class_candidate is not None and issubclass(class_obj, class_candidate):
            load_method_name = importable_classes[class_name][1]

    # if load method name is None, then we have a dummy module -> raise Error
    if load_method_name is None:
        none_module = class_obj.__module__
        is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
            TRANSFORMERS_DUMMY_MODULES_FOLDER
        )
        if is_dummy_path and "dummy" in none_module:
            # call class_obj for nice error message of missing requirements
            class_obj()

        raise ValueError(
            f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
            f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
        )

    load_method = getattr(class_obj, load_method_name)

    # add kwargs to loading method
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    loading_kwargs = {}
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype
    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
        loading_kwargs["provider"] = provider
        loading_kwargs["sess_options"] = sess_options

    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    if is_transformers_available():
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        # Placeholder only; every comparison below is guarded by `is_transformers_available()` (directly or
        # through `is_transformers_model`), so the string is never compared against a version.
        transformers_version = "N/A"

    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )

    # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
    # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
    # This makes sure that the weights won't be initialized which significantly speeds up loading.
    if is_diffusers_model or is_transformers_model:
        loading_kwargs["device_map"] = device_map
        loading_kwargs["max_memory"] = max_memory
        loading_kwargs["offload_folder"] = offload_folder
        loading_kwargs["offload_state_dict"] = offload_state_dict
        # Removes the entry from the caller's mapping.
        loading_kwargs["variant"] = model_variants.pop(name, None)

        if from_flax:
            loading_kwargs["from_flax"] = True

        # the following can be deleted once the minimum required `transformers` version
        # is higher than 4.27
        if (
            is_transformers_model
            and loading_kwargs["variant"] is not None
            and transformers_version < version.parse("4.27.0")
        ):
            raise ImportError(
                f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
            )
        elif is_transformers_model and loading_kwargs["variant"] is None:
            # transformers models only receive `variant` when one was actually requested
            loading_kwargs.pop("variant")

        # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
        if not (from_flax and is_transformers_model):
            loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
        else:
            loading_kwargs["low_cpu_mem_usage"] = False

    # check if the module is in a subdirectory
    if os.path.isdir(os.path.join(cached_folder, name)):
        loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
    else:
        # else load from the root directory
        loaded_sub_model = load_method(cached_folder, **loading_kwargs)

    if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
        # remove hooks
        remove_hook_from_module(loaded_sub_model, recurse=True)
        # The dict is indexed with the empty-string key; a `device_map` dict without that key raises `KeyError`.
        needs_offloading_to_cpu = device_map[""] == "cpu"

        if needs_offloading_to_cpu:
            dispatch_model(
                loaded_sub_model,
                state_dict=loaded_sub_model.state_dict(),
                device_map=device_map,
                force_hooks=True,
                main_device=0,
            )
        else:
            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)

    return loaded_sub_model
getattr(diffusers_module, "pipelines") # register the config from the original module, not the dynamo compiled one not_compiled_module = _unwrap_model(module) library = not_compiled_module.__module__.split(".")[0] # check if the module is a pipeline module module_path_items = not_compiled_module.__module__.split(".") pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None path = not_compiled_module.__module__.split(".") is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) # if library is not in LOADABLE_CLASSES, then it is a custom module. # Or if it's a pipeline module, then the module is inside the pipeline # folder so we set the library to module name. if is_pipeline_module: library = pipeline_dir elif library not in LOADABLE_CLASSES: library = not_compiled_module.__module__ # retrieve class_name class_name = not_compiled_module.__class__.__name__ return (library, class_name) ================================================ FILE: src/diffusers/pipelines/pipeline_utils.py ================================================ # coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import fnmatch import importlib import inspect import os import re import sys from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin import numpy as np import PIL.Image import requests import torch from huggingface_hub import ( ModelCard, create_repo, hf_hub_download, model_info, snapshot_download, ) from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args from packaging import version from requests.exceptions import HTTPError from tqdm.auto import tqdm from .. import __version__ from ..configuration_utils import ConfigMixin from ..models import AutoencoderKL from ..models.attention_processor import FusedAttnProcessor2_0 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( CONFIG_NAME, DEPRECATED_REVISION_ARGS, BaseOutput, PushToHubMixin, deprecate, is_accelerate_available, is_accelerate_version, is_torch_npu_available, is_torch_version, logging, numpy_to_pil, ) from ..utils.hub_utils import load_or_create_model_card, populate_model_card from ..utils.torch_utils import is_compiled_module if is_torch_npu_available(): import torch_npu # noqa: F401 from .pipeline_loading_utils import ( ALL_IMPORTABLE_CLASSES, CONNECTED_PIPES_KEYS, CUSTOM_PIPELINE_FILE_NAME, LOADABLE_CLASSES, _fetch_class_library_tuple, _get_custom_pipeline_class, _get_final_device_map, _get_pipeline_class, _unwrap_model, is_safetensors_compatible, load_sub_model, maybe_raise_or_warn, variant_compatible_siblings, warn_deprecated_model_variant, ) if is_accelerate_available(): import accelerate LIBRARIES = [] for library in LOADABLE_CLASSES: LIBRARIES.append(library) SUPPORTED_DEVICE_MAP = ["balanced"] logger = logging.get_logger(__name__) @dataclass class ImagePipelineOutput(BaseOutput): """ Output class for image pipelines. 
@dataclass
class AudioPipelineOutput(BaseOutput):
    """
    Output class for audio pipelines.

    Args:
        audios (`np.ndarray`)
            Denoised audio samples as a NumPy array of shape `(batch_size, num_channels, sample_rate)`.
    """

    audios: np.ndarray
def register_modules(self, **kwargs):
    """
    Register pipeline components on the pipeline and in its config.

    For every `name=module` pair the config receives `name: (library, class_name)` as returned by
    `_fetch_class_library_tuple`, or `name: (None, None)` when the module is `None` or a tuple/list whose
    first element is `None`. The module itself is then set as attribute `name` on the pipeline.
    """
    for component_name, component in kwargs.items():
        first_item_missing = isinstance(component, (tuple, list)) and component[0] is None

        if component is None or first_item_missing:
            config_entry = (None, None)
        else:
            # retrieve library
            library, class_name = _fetch_class_library_tuple(component)
            config_entry = (library, class_name)

        # save model index config
        self.register_to_config(**{component_name: config_entry})

        # set models
        setattr(self, component_name, component)
push_to_hub (`bool`, *optional*, defaults to `False`): Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ model_index_dict = dict(self.config) model_index_dict.pop("_class_name", None) model_index_dict.pop("_diffusers_version", None) model_index_dict.pop("_module", None) model_index_dict.pop("_name_or_path", None) if push_to_hub: commit_message = kwargs.pop("commit_message", None) private = kwargs.pop("private", False) create_pr = kwargs.pop("create_pr", False) token = kwargs.pop("token", None) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id expected_modules, optional_kwargs = self._get_signature_keys(self) def is_saveable_module(name, value): if name not in expected_modules: return False if name in self._optional_components and value[0] is None: return False return True model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)} for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) model_cls = sub_model.__class__ # Dynamo wraps the original model in a private class. # I didn't find a public API to get the original class. if is_compiled_module(sub_model): sub_model = _unwrap_model(sub_model) model_cls = sub_model.__class__ save_method_name = None # search for the model's base class in LOADABLE_CLASSES for library_name, library_classes in LOADABLE_CLASSES.items(): if library_name in sys.modules: library = importlib.import_module(library_name) else: logger.info( f"{library_name} is not installed. 
def to(self, *args, **kwargs):
    r"""
    Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
    arguments of `self.to(*args, **kwargs).`

    The `torch.nn.Module` components of the pipeline are converted in place and the pipeline itself is
    returned.

    Here are the ways to call `to`:

    - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
      [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
    - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
      [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
    - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the
      specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and
      [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)

    Arguments:
        dtype (`torch.dtype`, *optional*):
            Returns a pipeline with the specified
            [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
        device (`torch.Device`, *optional*):
            Returns a pipeline with the specified
            [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
        silence_dtype_warnings (`bool`, *optional*, defaults to `False`):
            Whether to omit warnings if the target `dtype` is not compatible with the target `device`.

    Returns:
        [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `device`.

    Raises:
        ValueError: If more than two positional arguments are given, if a dtype is passed first together with
            a second positional argument, if `dtype`/`device` are given both positionally and by keyword, if a
            sequentially offloaded pipeline is moved to CUDA, or if the pipeline is device-mapped.
    """
    dtype = kwargs.pop("dtype", None)
    device = kwargs.pop("device", None)
    silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)

    dtype_arg = None
    device_arg = None
    if len(args) == 1:
        if isinstance(args[0], torch.dtype):
            dtype_arg = args[0]
        else:
            device_arg = torch.device(args[0]) if args[0] is not None else None
    elif len(args) == 2:
        if isinstance(args[0], torch.dtype):
            raise ValueError(
                "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`."
            )
        device_arg = torch.device(args[0]) if args[0] is not None else None
        dtype_arg = args[1]
    elif len(args) > 2:
        raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`")

    if dtype is not None and dtype_arg is not None:
        raise ValueError(
            "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two."
        )

    dtype = dtype or dtype_arg

    if device is not None and device_arg is not None:
        raise ValueError(
            "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two."
        )

    device = device or device_arg

    # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
    def module_is_sequentially_offloaded(module):
        if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
            return False

        return hasattr(module, "_hf_hook") and (
            isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
            or hasattr(module._hf_hook, "hooks")
            and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
        )

    def module_is_offloaded(module):
        if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
            return False

        return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)

    # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
    pipeline_is_sequentially_offloaded = any(
        module_is_sequentially_offloaded(module) for _, module in self.components.items()
    )
    if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
        raise ValueError(
            "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
        )

    is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
    if is_pipeline_device_mapped:
        raise ValueError(
            "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
        )

    # Display a warning in this case (the operation succeeds but the benefits are lost)
    pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
    if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
        logger.warning(
            f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
        )

    module_names, _ = self._get_signature_keys(self)
    modules = [getattr(self, n, None) for n in module_names]
    modules = [m for m in modules if isinstance(m, torch.nn.Module)]

    is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
    for module in modules:
        is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit

        if is_loaded_in_8bit:
            # 8bit modules support neither dtype conversion nor device moves through `.to()`: warn about
            # whatever was requested and leave the module untouched.
            if dtype is not None:
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
                )
            if device is not None:
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {device} via `.to()` is not yet supported. Module is still on {module.device}."
                )
        else:
            module.to(device, dtype)

        if (
            module.dtype == torch.float16
            and str(device) in ["cpu"]
            and not silence_dtype_warnings
            and not is_offloaded
        ):
            logger.warning(
                "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
                " is not recommended to move them to `cpu` as running them will fail. Please make"
                " sure to use an accelerator to run the pipeline in inference, due to the lack of"
                " support for`float16` operations on this device in PyTorch. Please, remove the"
                " `torch_dtype=torch.float16` argument, or use another device for inference."
            )
    return self
""" module_names, _ = self._get_signature_keys(self) modules = [getattr(self, n, None) for n in module_names] modules = [m for m in modules if isinstance(m, torch.nn.Module)] for module in modules: return module.device return torch.device("cpu") @property def dtype(self) -> torch.dtype: r""" Returns: `torch.dtype`: The torch dtype on which the pipeline is located. """ module_names, _ = self._get_signature_keys(self) modules = [getattr(self, n, None) for n in module_names] modules = [m for m in modules if isinstance(m, torch.nn.Module)] for module in modules: return module.dtype return torch.float32 @classmethod @validate_hf_hub_args def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights. The pipeline is set in evaluation mode (`model.eval()`) by default. If you get the error message below, you need to finetune the weights for your downstream task: ``` Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. ``` Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted on the Hub. - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights saved using [`~DiffusionPipeline.save_pretrained`]. torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the dtype is automatically derived from the model's weights. 
custom_pipeline (`str`, *optional*): 🧪 This is an experimental feature and may change in the future. Can be either: - A string, the *repo id* (for example `hf-internal-testing/diffusers-dummy-pipeline`) of a custom pipeline hosted on the Hub. The repository must contain a file called pipeline.py that defines the custom pipeline. - A string, the *file name* of a community pipeline hosted on GitHub under [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file names must match the file name and not the pipeline script (`clip_guided_stable_diffusion` instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the current main branch of GitHub. - A path to a directory (`./my_pipeline_directory/`) containing a custom pipeline. The directory must contain a file called `pipeline.py` that defines the custom pipeline. For more information on how to load and create custom pipelines, please have a look at [Loading and Adding Custom Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. 
If set to `True`, the model won't be downloaded from the Hub. token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. custom_revision (`str`, *optional*): The specific model version to use. It can be a branch name, a tag name, or a commit id similar to `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers version. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn’t need to be defined for each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). max_memory (`Dict`, *optional*): A dictionary device identifier for the maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. offload_folder (`str` or `os.PathLike`, *optional*): The path to offload weights if device_map contains the value `"disk"`. offload_state_dict (`bool`, *optional*): If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. 
Defaults to `True` when there is some disk offload. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading only loading the pretrained weights and not initializing the weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to `True` will raise an error. use_safetensors (`bool`, *optional*, defaults to `None`): If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors weights. If set to `False`, safetensors weights are not loaded. use_onnx (`bool`, *optional*, defaults to `None`): If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending with `.onnx` and `.pb`. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline class). The overwritten components are passed directly to the pipelines `__init__` method. See example below for more information. variant (`str`, *optional*): Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `huggingface-cli login`. Examples: ```py >>> from diffusers import DiffusionPipeline >>> # Download pipeline from huggingface.co and cache. 
>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") >>> # Download pipeline that requires an authorization token >>> # For more information on access tokens, please refer to this section >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") >>> # Use a different scheduler >>> from diffusers import LMSDiscreteScheduler >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) >>> pipeline.scheduler = scheduler ``` """ cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) revision = kwargs.pop("revision", None) from_flax = kwargs.pop("from_flax", False) torch_dtype = kwargs.pop("torch_dtype", None) custom_pipeline = kwargs.pop("custom_pipeline", None) custom_revision = kwargs.pop("custom_revision", None) provider = kwargs.pop("provider", None) sess_options = kwargs.pop("sess_options", None) device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) if low_cpu_mem_usage and not is_accelerate_available(): low_cpu_mem_usage = False logger.warning( "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" " `accelerate` for faster and less memory-intense model loading. 
You can do so with: \n```\npip" " install accelerate\n```\n." ) if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" " `low_cpu_mem_usage=False`." ) if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" " `device_map=None`." ) if device_map is not None and not is_accelerate_available(): raise NotImplementedError( "Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`." ) if device_map is not None and not isinstance(device_map, str): raise ValueError("`device_map` must be a string.") if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP: raise NotImplementedError( f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}" ) if device_map is not None and device_map in SUPPORTED_DEVICE_MAP: if is_accelerate_version("<", "0.28.0"): raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.") if low_cpu_mem_usage is False and device_map is not None: raise ValueError( f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" " dispatching. Please make sure to set `low_cpu_mem_usage=True`." ) # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): if pretrained_model_name_or_path.count("/") > 1: raise ValueError( f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"' " is neither a valid local path nor a valid repo id. Please check the parameter." 
) cached_folder = cls.download( pretrained_model_name_or_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision, from_flax=from_flax, use_safetensors=use_safetensors, use_onnx=use_onnx, custom_pipeline=custom_pipeline, custom_revision=custom_revision, variant=variant, load_connected_pipeline=load_connected_pipeline, **kwargs, ) else: cached_folder = pretrained_model_name_or_path config_dict = cls.load_config(cached_folder) # pop out "_ignore_files" as it is only needed for download config_dict.pop("_ignore_files", None) # 2. Define which model components should load variants # We retrieve the information by matching whether variant # model checkpoints exist in the subfolders model_variants = {} if variant is not None: for folder in os.listdir(cached_folder): folder_path = os.path.join(cached_folder, folder) is_folder = os.path.isdir(folder_path) and folder in config_dict variant_exists = is_folder and any( p.split(".")[1].startswith(variant) for p in os.listdir(folder_path) ) if variant_exists: model_variants[folder] = variant # 3. 
Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it custom_class_name = None if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")): custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py") elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile( os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py") ): custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py") custom_class_name = config_dict["_class_name"][1] pipeline_class = _get_pipeline_class( cls, config_dict, load_connected_pipeline=load_connected_pipeline, custom_pipeline=custom_pipeline, class_name=custom_class_name, cache_dir=cache_dir, revision=custom_revision, ) if device_map is not None and pipeline_class._load_connected_pipes: raise NotImplementedError("`device_map` is not yet supported for connected pipelines.") # DEPRECATED: To be removed in 1.0.0 if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( version.parse(config_dict["_diffusers_version"]).base_version ) <= version.parse("0.5.1"): from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy pipeline_class = StableDiffusionInpaintPipelineLegacy deprecation_message = ( "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the" f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For" " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting" " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your" f" checkpoint {pretrained_model_name_or_path} to the format of" " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain" " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0." 
) deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) # 4. Define expected modules given pipeline signature # and define non-None initialized modules (=`init_kwargs`) # some modules can be passed directly to the init # in this case they are already instantiated in `kwargs` # extract them here expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) # define init kwargs and make sure that optional component modules are filtered out init_kwargs = { k: init_dict.pop(k) for k in optional_kwargs if k in init_dict and k not in pipeline_class._optional_components } init_kwargs = {**init_kwargs, **passed_pipe_kwargs} # remove `null` components def load_module(name, value): if value[0] is None: return False if name in passed_class_obj and passed_class_obj[name] is None: return False return True init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} # Special case: safety_checker must be loaded separately when using `from_flax` if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: raise NotImplementedError( "The safety checker cannot be automatically loaded when loading weights `from_flax`." " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker" " separately if you need it." ) # 5. Throw nice warnings / errors for fast accelerate loading if len(unused_kwargs) > 0: logger.warning( f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." ) # import it here to avoid circular import from diffusers import pipelines # 6. 
device map delegation final_device_map = None if device_map is not None: final_device_map = _get_final_device_map( device_map=device_map, pipeline_class=pipeline_class, passed_class_obj=passed_class_obj, init_dict=init_dict, library=library, max_memory=max_memory, torch_dtype=torch_dtype, cached_folder=cached_folder, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision, ) # 7. Load each module in the pipeline current_device_map = None for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."): if final_device_map is not None and len(final_device_map) > 0: component_device = final_device_map.get(name, None) if component_device is not None: current_device_map = {"": component_device} else: current_device_map = None # 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names class_name = class_name[4:] if class_name.startswith("Flax") else class_name # 7.2 Define all importable classes is_pipeline_module = hasattr(pipelines, library_name) importable_classes = ALL_IMPORTABLE_CLASSES loaded_sub_model = None # 7.3 Use passed sub model or load class_name from library_name if name in passed_class_obj: # if the model is in a pipeline module, then we load it from the pipeline # check that passed_class_obj has correct parent class maybe_raise_or_warn( library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module ) loaded_sub_model = passed_class_obj[name] else: # load sub model loaded_sub_model = load_sub_model( library_name=library_name, class_name=class_name, importable_classes=importable_classes, pipelines=pipelines, is_pipeline_module=is_pipeline_module, pipeline_class=pipeline_class, torch_dtype=torch_dtype, provider=provider, sess_options=sess_options, device_map=current_device_map, max_memory=max_memory, offload_folder=offload_folder, offload_state_dict=offload_state_dict, 
model_variants=model_variants, name=name, from_flax=from_flax, variant=variant, low_cpu_mem_usage=low_cpu_mem_usage, cached_folder=cached_folder, ) logger.info( f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." ) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")): modelcard = ModelCard.load(os.path.join(cached_folder, "README.md")) connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS} load_kwargs = { "cache_dir": cache_dir, "force_download": force_download, "proxies": proxies, "local_files_only": local_files_only, "token": token, "revision": revision, "torch_dtype": torch_dtype, "custom_pipeline": custom_pipeline, "custom_revision": custom_revision, "provider": provider, "sess_options": sess_options, "device_map": device_map, "max_memory": max_memory, "offload_folder": offload_folder, "offload_state_dict": offload_state_dict, "low_cpu_mem_usage": low_cpu_mem_usage, "variant": variant, "use_safetensors": use_safetensors, } def get_connected_passed_kwargs(prefix): connected_passed_class_obj = { k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix } connected_passed_pipe_kwargs = { k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix } connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs} return connected_passed_kwargs connected_pipes = { prefix: DiffusionPipeline.from_pretrained( repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix) ) for prefix, repo_id in connected_pipes.items() if repo_id is not None } for prefix, connected_pipe in connected_pipes.items(): # add connected pipes to `init_kwargs` with _, e.g. 
"prior_text_encoder" init_kwargs.update( {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()} ) # 8. Potentially add passed objects if expected missing_modules = set(expected_modules) - set(init_kwargs.keys()) passed_modules = list(passed_class_obj.keys()) optional_modules = pipeline_class._optional_components if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): for module in missing_modules: init_kwargs[module] = passed_class_obj.get(module, None) elif len(missing_modules) > 0: passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs raise ValueError( f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." ) # 10. Instantiate the pipeline model = pipeline_class(**init_kwargs) # 11. Save where the model was instantiated from model.register_to_config(_name_or_path=pretrained_model_name_or_path) if device_map is not None: setattr(model, "hf_device_map", final_device_map) return model @property def name_or_path(self) -> str: return getattr(self.config, "_name_or_path", None) @property def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from Accelerate's module hooks. """ for name, model in self.components.items(): if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: continue if not hasattr(model, "_hf_hook"): return self.device for module in model.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def remove_all_hooks(self): r""" Removes all hooks that were added when using `enable_sequential_cpu_offload` or `enable_model_cpu_offload`. 
""" for _, model in self.components.items(): if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"): accelerate.hooks.remove_hook_from_module(model, recurse=True) self._all_hooks = [] def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. Arguments: gpu_id (`int`, *optional*): The ID of the accelerator that shall be used in inference. If not specified, it will default to 0. device (`torch.Device` or `str`, *optional*, defaults to "cuda"): The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will default to "cuda". """ is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: raise ValueError( "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`." ) if self.model_cpu_offload_seq is None: raise ValueError( "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set." 
) if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") self.remove_all_hooks() torch_device = torch.device(device) device_index = torch_device.index if gpu_id is not None and device_index is not None: raise ValueError( f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}" f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}" ) # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0 self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0) device_type = torch_device.type device = torch.device(f"{device_type}:{self._offload_gpu_id}") self._offload_device = device self.to("cpu", silence_dtype_warnings=True) device_mod = getattr(torch, device.type, None) if hasattr(device_mod, "empty_cache") and device_mod.is_available(): device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist) all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)} self._all_hooks = [] hook = None for model_str in self.model_cpu_offload_seq.split("->"): model = all_model_components.pop(model_str, None) if not isinstance(model, torch.nn.Module): continue _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook) self._all_hooks.append(hook) # CPU offload models that are not in the seq chain unless they are explicitly excluded # these models will stay on CPU until maybe_free_model_hooks is called # some models cannot be in the seq chain because they are iteratively called, such as controlnet for name, model in all_model_components.items(): if not isinstance(model, torch.nn.Module): continue if name in 
self._exclude_from_cpu_offload: model.to(device) else: _, hook = cpu_offload_with_hook(model, device) self._all_hooks.append(hook) def maybe_free_model_hooks(self): r""" Function that offloads all components, removes all model hooks that were added when using `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions correctly when applying enable_model_cpu_offload. """ if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: # `enable_model_cpu_offload` has not be called, so silently do nothing return # make sure the model is in the same state as before calling it self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda")) def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): r""" Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called. Offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. Arguments: gpu_id (`int`, *optional*): The ID of the accelerator that shall be used in inference. If not specified, it will default to 0. device (`torch.Device` or `str`, *optional*, defaults to "cuda"): The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will default to "cuda". 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") self.remove_all_hooks() is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: raise ValueError( "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`." ) torch_device = torch.device(device) device_index = torch_device.index if gpu_id is not None and device_index is not None: raise ValueError( f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}" f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}" ) # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0 self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0) device_type = torch_device.type device = torch.device(f"{device_type}:{self._offload_gpu_id}") self._offload_device = device if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) device_mod = getattr(torch, self.device.type, None) if hasattr(device_mod, "empty_cache") and device_mod.is_available(): device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for name, model in self.components.items(): if not isinstance(model, torch.nn.Module): continue if name in self._exclude_from_cpu_offload: model.to(device) else: # make sure to offload buffers if not all high level weights # are of type nn.Module offload_buffers = len(model._parameters) > 0 cpu_offload(model, device, offload_buffers=offload_buffers) def reset_device_map(self): 
r""" Resets the device maps (if any) to None. """ if self.hf_device_map is None: return else: self.remove_all_hooks() for name, component in self.components.items(): if isinstance(component, torch.nn.Module): component.to("cpu") self.hf_device_map = None @classmethod @validate_hf_hub_args def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: r""" Download and cache a PyTorch diffusion pipeline from pretrained pipeline weights. Parameters: pretrained_model_name (`str` or `os.PathLike`, *optional*): A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted on the Hub. custom_pipeline (`str`, *optional*): Can be either: - A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted on the Hub. The repository must contain a file called `pipeline.py` that defines the custom pipeline. - A string, the *file name* of a community pipeline hosted on GitHub under [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file names must match the file name and not the pipeline script (`clip_guided_stable_diffusion` instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the current `main` branch of GitHub. - A path to a *directory* (`./my_pipeline_directory/`) containing a custom pipeline. The directory must contain a file called `pipeline.py` that defines the custom pipeline. 🧪 This is an experimental feature and may change in the future. For more information on how to load and create custom pipelines, take a look at [How to contribute a community pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/contribute_pipeline). force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. 
proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. custom_revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id similar to `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you're downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. variant (`str`, *optional*): Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. use_safetensors (`bool`, *optional*, defaults to `None`): If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors weights. If set to `False`, safetensors weights are not loaded. 
use_onnx (`bool`, *optional*, defaults to `False`): If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending with `.onnx` and `.pb`. trust_remote_code (`bool`, *optional*, defaults to `False`): Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine. Returns: `os.PathLike`: A path to the downloaded pipeline. To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `huggingface-cli login`. """ cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) revision = kwargs.pop("revision", None) from_flax = kwargs.pop("from_flax", False) custom_pipeline = kwargs.pop("custom_pipeline", None) custom_revision = kwargs.pop("custom_revision", None) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) trust_remote_code = kwargs.pop("trust_remote_code", False) allow_pickle = False if use_safetensors is None: use_safetensors = True allow_pickle = True allow_patterns = None ignore_patterns = None model_info_call_error: Optional[Exception] = None if not local_files_only: try: info = model_info(pretrained_model_name, token=token, revision=revision) except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e: logger.warning(f"Couldn't connect to the Hub: {e}.\nWill 
try to load from local cache.") local_files_only = True model_info_call_error = e # save error to reraise it if model is not cached locally if not local_files_only: config_file = hf_hub_download( pretrained_model_name, cls.config_name, cache_dir=cache_dir, revision=revision, proxies=proxies, force_download=force_download, token=token, ) config_dict = cls._dict_from_json_file(config_file) ignore_filenames = config_dict.pop("_ignore_files", []) # retrieve all folder_names that contain relevant files folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] filenames = {sibling.rfilename for sibling in info.siblings} model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) diffusers_module = importlib.import_module(__name__.split(".")[0]) pipelines = getattr(diffusers_module, "pipelines") # optionally create a custom component <> custom file mapping custom_components = {} for component in folder_names: module_candidate = config_dict[component][0] if module_candidate is None or not isinstance(module_candidate, str): continue # We compute candidate file path on the Hub. Do not use `os.path.join`. candidate_file = f"{component}/{module_candidate}.py" if candidate_file in filenames: custom_components[component] = module_candidate elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate): raise ValueError( f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." ) if len(variant_filenames) == 0 and variant is not None: deprecation_message = ( f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`" "if such variant modeling files are not available. 
Doing so will lead to an error in v0.24.0 as defaulting to non-variant" "modeling files is deprecated." ) deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False) # remove ignored filenames model_filenames = set(model_filenames) - set(ignore_filenames) variant_filenames = set(variant_filenames) - set(ignore_filenames) # if the whole pipeline is cached we don't have to ping the Hub if revision in DEPRECATED_REVISION_ARGS and version.parse( version.parse(__version__).base_version ) >= version.parse("0.22.0"): warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames) model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} custom_class_name = None if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)): custom_pipeline = config_dict["_class_name"][0] custom_class_name = config_dict["_class_name"][1] # all filenames compatible with variant will be added allow_patterns = list(model_filenames) # allow all patterns from non-model folders # this enables downloading schedulers, tokenizers, ... 
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names] # add custom component files allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()] # add custom pipeline file allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else [] # also allow downloading config.json files with the model allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] allow_patterns += [ SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name, CUSTOM_PIPELINE_FILE_NAME, ] load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames load_components_from_hub = len(custom_components) > 0 if load_pipe_from_hub and not trust_remote_code: raise ValueError( f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly " f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n" f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." ) if load_components_from_hub and not trust_remote_code: raise ValueError( f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly " f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n" f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." 
) # retrieve passed components that should not be downloaded pipeline_class = _get_pipeline_class( cls, config_dict, load_connected_pipeline=load_connected_pipeline, custom_pipeline=custom_pipeline, repo_id=pretrained_model_name if load_pipe_from_hub else None, hub_revision=revision, class_name=custom_class_name, cache_dir=cache_dir, revision=custom_revision, ) expected_components, _ = cls._get_signature_keys(pipeline_class) passed_components = [k for k in expected_components if k in kwargs] if ( use_safetensors and not allow_pickle and not is_safetensors_compatible( model_filenames, variant=variant, passed_components=passed_components ) ): raise EnvironmentError( f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})" ) if from_flax: ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] elif use_safetensors and is_safetensors_compatible( model_filenames, variant=variant, passed_components=passed_components ): ignore_patterns = ["*.bin", "*.msgpack"] use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx if not use_onnx: ignore_patterns += ["*.onnx", "*.pb"] safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} if ( len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames ): logger.warning( f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." 
) else: ignore_patterns = ["*.safetensors", "*.msgpack"] use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx if not use_onnx: ignore_patterns += ["*.onnx", "*.pb"] bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: logger.warning( f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." ) # Don't download any objects that are passed allow_patterns = [ p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components) ] if pipeline_class._load_connected_pipes: allow_patterns.append("README.md") # Don't download index files of forbidden patterns either ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns] re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns] re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns] expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)] expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)] snapshot_folder = Path(config_file).parent pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files) if pipeline_is_cached and not force_download: # if the pipeline is cached, we can directly return it # else call snapshot_download return snapshot_folder user_agent = {"pipeline_class": cls.__name__} if custom_pipeline is not None and not custom_pipeline.endswith(".py"): user_agent["custom_pipeline"] = custom_pipeline # download all allow_patterns - ignore_patterns try: cached_folder = snapshot_download( 
pretrained_model_name, cache_dir=cache_dir, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, user_agent=user_agent, ) # retrieve pipeline class from local file cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None) cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name diffusers_module = importlib.import_module(__name__.split(".")[0]) pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None if pipeline_class is not None and pipeline_class._load_connected_pipes: modelcard = ModelCard.load(os.path.join(cached_folder, "README.md")) connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], []) for connected_pipe_repo_id in connected_pipes: download_kwargs = { "cache_dir": cache_dir, "force_download": force_download, "proxies": proxies, "local_files_only": local_files_only, "token": token, "variant": variant, "use_safetensors": use_safetensors, } DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs) return cached_folder except FileNotFoundError: # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache. # This can happen in two cases: # 1. If the user passed `local_files_only=True` => we raise the error directly # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error if model_info_call_error is None: # 1. user passed `local_files_only=True` raise else: # 2. we forced `local_files_only=True` when `model_info` failed raise EnvironmentError( f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred" " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace" " above." 
) from model_info_call_error @classmethod def _get_signature_keys(cls, obj): parameters = inspect.signature(obj.__init__).parameters required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) expected_modules = set(required_parameters.keys()) - {"self"} optional_names = list(optional_parameters) for name in optional_names: if name in cls._optional_components: expected_modules.add(name) optional_parameters.remove(name) return expected_modules, optional_parameters @classmethod def _get_signature_types(cls): signature_types = {} for k, v in inspect.signature(cls.__init__).parameters.items(): if inspect.isclass(v.annotation): signature_types[k] = (v.annotation,) elif get_origin(v.annotation) == Union: signature_types[k] = get_args(v.annotation) else: logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.") return signature_types @property def components(self) -> Dict[str, Any]: r""" The `self.components` property can be useful to run different pipelines with the same weights and configurations without reallocating additional memory. Returns (`dict`): A dictionary containing all the modules needed to initialize the pipeline. Examples: ```py >>> from diffusers import ( ... StableDiffusionPipeline, ... StableDiffusionImg2ImgPipeline, ... StableDiffusionInpaintPipeline, ... 
) >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components) >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components) ``` """ expected_modules, optional_parameters = self._get_signature_keys(self) components = { k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters } if set(components.keys()) != expected_modules: raise ValueError( f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" f" {expected_modules} to be defined, but {components.keys()} are defined." ) return components @staticmethod def numpy_to_pil(images): """ Convert a NumPy image or a batch of images to a PIL image. """ return numpy_to_pil(images) def progress_bar(self, iterable=None, total=None): if not hasattr(self, "_progress_bar_config"): self._progress_bar_config = {} elif not isinstance(self._progress_bar_config, dict): raise ValueError( f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." ) if iterable is not None: return tqdm(iterable, **self._progress_bar_config) elif total is not None: return tqdm(total=total, **self._progress_bar_config) else: raise ValueError("Either `total` or `iterable` has to be defined.") def set_progress_bar_config(self, **kwargs): self._progress_bar_config = kwargs def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): r""" Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). When this option is enabled, you should observe lower GPU memory usage and a potential speed up during inference. Speed up during training is not guaranteed. ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes precedent. 
Parameters: attention_op (`Callable`, *optional*): Override the default `None` operator for use as `op` argument to the [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention) function of xFormers. Examples: ```py >>> import torch >>> from diffusers import DiffusionPipeline >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) >>> # Workaround for not accepting attention shape using VAE for Flash Attention >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None) ``` """ self.set_use_memory_efficient_attention_xformers(True, attention_op) def disable_xformers_memory_efficient_attention(self): r""" Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). """ self.set_use_memory_efficient_attention_xformers(False) def set_use_memory_efficient_attention_xformers( self, valid: bool, attention_op: Optional[Callable] = None ) -> None: # Recursively walk through all the children. # Any children which exposes the set_use_memory_efficient_attention_xformers method # gets the message def fn_recursive_set_mem_eff(module: torch.nn.Module): if hasattr(module, "set_use_memory_efficient_attention_xformers"): module.set_use_memory_efficient_attention_xformers(valid, attention_op) for child in module.children(): fn_recursive_set_mem_eff(child) module_names, _ = self._get_signature_keys(self) modules = [getattr(self, n, None) for n in module_names] modules = [m for m in modules if isinstance(m, torch.nn.Module)] for module in modules: fn_recursive_set_mem_eff(module) def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" Enable sliced attention computation. 
When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. For more than one attention head, the computation is performed sequentially over each head. This is useful to save some memory in exchange for a small speed decrease. ⚠️ Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers. These attention computations are already very memory efficient so you won't need to enable this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious slow downs! Args: slice_size (`str` or `int`, *optional*, defaults to `"auto"`): When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. Examples: ```py >>> import torch >>> from diffusers import StableDiffusionPipeline >>> pipe = StableDiffusionPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", ... torch_dtype=torch.float16, ... use_safetensors=True, ... ) >>> prompt = "a photo of an astronaut riding a horse on mars" >>> pipe.enable_attention_slicing() >>> image = pipe(prompt).images[0] ``` """ self.set_attention_slice(slice_size) def disable_attention_slicing(self): r""" Disable sliced attention computation. If `enable_attention_slicing` was previously called, attention is computed in one step. 
""" # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) def set_attention_slice(self, slice_size: Optional[int]): module_names, _ = self._get_signature_keys(self) modules = [getattr(self, n, None) for n in module_names] modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")] for module in modules: module.set_attention_slice(slice_size) @classmethod def from_pipe(cls, pipeline, **kwargs): r""" Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing pipeline components without reallocating additional memory. Arguments: pipeline (`DiffusionPipeline`): The pipeline from which to create a new pipeline. Returns: `DiffusionPipeline`: A new pipeline with the same weights and configurations as `pipeline`. Examples: ```py >>> from diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") >>> new_pipe = StableDiffusionSAGPipeline.from_pipe(pipe) ``` """ original_config = dict(pipeline.config) torch_dtype = kwargs.pop("torch_dtype", None) # derive the pipeline class to instantiate custom_pipeline = kwargs.pop("custom_pipeline", None) custom_revision = kwargs.pop("custom_revision", None) if custom_pipeline is not None: pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision) else: pipeline_class = cls expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) # true_optional_modules are optional components with default value in signature so it is ok not to pass them to `__init__` # e.g. `image_encoder` for StableDiffusionPipeline parameters = inspect.signature(cls.__init__).parameters true_optional_modules = set( {k for k, v in parameters.items() if v.default != inspect._empty and k in expected_modules} ) # get the class of each component based on its type hint # e.g. 
{"unet": UNet2DConditionModel, "text_encoder": CLIPTextMode} component_types = pipeline_class._get_signature_types() pretrained_model_name_or_path = original_config.pop("_name_or_path", None) # allow users pass modules in `kwargs` to override the original pipeline's components passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} original_class_obj = {} for name, component in pipeline.components.items(): if name in expected_modules and name not in passed_class_obj: # for model components, we will not switch over if the class does not matches the type hint in the new pipeline's signature if ( not isinstance(component, ModelMixin) or type(component) in component_types[name] or (component is None and name in cls._optional_components) ): original_class_obj[name] = component else: logger.warning( f"component {name} is not switched over to new pipeline because type does not match the expected." f" {name} is {type(component)} while the new pipeline expect {component_types[name]}." f" please pass the component of the correct type to the new pipeline. `from_pipe(..., {name}={name})`" ) # allow users pass optional kwargs to override the original pipelines config attribute passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} original_pipe_kwargs = { k: original_config[k] for k in original_config.keys() if k in optional_kwargs and k not in passed_pipe_kwargs } # config attribute that were not expected by pipeline is stored as its private attribute # (i.e. 
when the original pipeline was also instantiated with `from_pipe` from another pipeline that has this config) # in this case, we will pass them as optional arguments if they can be accepted by the new pipeline additional_pipe_kwargs = [ k[1:] for k in original_config.keys() if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs ] for k in additional_pipe_kwargs: original_pipe_kwargs[k] = original_config.pop(f"_{k}") pipeline_kwargs = { **passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs, **kwargs, } # store unused config as private attribute in the new pipeline unused_original_config = { f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs } missing_modules = ( set(expected_modules) - set(pipeline._optional_components) - set(pipeline_kwargs.keys()) - set(true_optional_modules) ) if len(missing_modules) > 0: raise ValueError( f"Pipeline {pipeline_class} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed" ) new_pipeline = pipeline_class(**pipeline_kwargs) if pretrained_model_name_or_path is not None: new_pipeline.register_to_config(_name_or_path=pretrained_model_name_or_path) new_pipeline.register_to_config(**unused_original_config) if torch_dtype is not None: new_pipeline.to(dtype=torch_dtype) return new_pipeline class StableDiffusionMixin: r""" Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion) """ def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to computing decoding in one step. 
""" self.vae.disable_slicing() def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images. """ self.vae.enable_tiling() def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to computing decoding in one step. """ self.vae.disable_tiling() def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. The suffixes after the scaling factors represent the stages where they are being applied. Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. Args: s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate "oversmoothing effect" in the enhanced denoising process. s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate "oversmoothing effect" in the enhanced denoising process. b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. """ if not hasattr(self, "unet"): raise ValueError("The pipeline must have `unet` for using FreeU.") self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() def fuse_qkv_projections(self, unet: bool = True, vae: bool = True): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. 
For cross-attention modules, key and value projection matrices are fused. This API is 🧪 experimental. Args: unet (`bool`, defaults to `True`): To apply fusion on the UNet. vae (`bool`, defaults to `True`): To apply fusion on the VAE. """ self.fusing_unet = False self.fusing_vae = False if unet: self.fusing_unet = True self.unet.fuse_qkv_projections() self.unet.set_attn_processor(FusedAttnProcessor2_0()) if vae: if not isinstance(self.vae, AutoencoderKL): raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.") self.fusing_vae = True self.vae.fuse_qkv_projections() self.vae.set_attn_processor(FusedAttnProcessor2_0()) def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): """Disable QKV projection fusion if enabled. This API is 🧪 experimental. Args: unet (`bool`, defaults to `True`): To apply fusion on the UNet. vae (`bool`, defaults to `True`): To apply fusion on the VAE. """ if unet: if not self.fusing_unet: logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.") else: self.unet.unfuse_qkv_projections() self.fusing_unet = False if vae: if not self.fusing_vae: logger.warning("The VAE was not initially fused for QKV projections. 
Doing nothing.") else: self.vae.unfuse_qkv_projections() self.fusing_vae = False ================================================ FILE: src/diffusers/pipelines/pixart_alpha/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"] _import_structure["pipeline_pixart_sigma"] = ["PixArtSigmaPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_pixart_alpha import ( ASPECT_RATIO_256_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_1024_BIN, PixArtAlphaPipeline, ) from .pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN, PixArtSigmaPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py ================================================ # Copyright 2024 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import html import inspect import re import urllib.parse as ul from typing import Callable, List, Optional, Tuple, Union import torch from transformers import T5EncoderModel, T5Tokenizer from ...image_processor import PixArtImageProcessor from ...models import AutoencoderKL, PixArtTransformer2DModel from ...schedulers import DPMSolverMultistepScheduler from ...utils import ( BACKENDS_MAPPING, deprecate, is_bs4_available, is_ftfy_available, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_bs4_available(): from bs4 import BeautifulSoup if is_ftfy_available(): import ftfy EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import PixArtAlphaPipeline >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too. >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) >>> # Enable memory optimizations. >>> pipe.enable_model_cpu_offload() >>> prompt = "A small cactus with a happy face in the Sahara desert." 
>>> image = pipe(prompt).images[0] ``` """ ASPECT_RATIO_1024_BIN = { "0.25": [512.0, 2048.0], "0.28": [512.0, 1856.0], "0.32": [576.0, 1792.0], "0.33": [576.0, 1728.0], "0.35": [576.0, 1664.0], "0.4": [640.0, 1600.0], "0.42": [640.0, 1536.0], "0.48": [704.0, 1472.0], "0.5": [704.0, 1408.0], "0.52": [704.0, 1344.0], "0.57": [768.0, 1344.0], "0.6": [768.0, 1280.0], "0.68": [832.0, 1216.0], "0.72": [832.0, 1152.0], "0.78": [896.0, 1152.0], "0.82": [896.0, 1088.0], "0.88": [960.0, 1088.0], "0.94": [960.0, 1024.0], "1.0": [1024.0, 1024.0], "1.07": [1024.0, 960.0], "1.13": [1088.0, 960.0], "1.21": [1088.0, 896.0], "1.29": [1152.0, 896.0], "1.38": [1152.0, 832.0], "1.46": [1216.0, 832.0], "1.67": [1280.0, 768.0], "1.75": [1344.0, 768.0], "2.0": [1408.0, 704.0], "2.09": [1472.0, 704.0], "2.4": [1536.0, 640.0], "2.5": [1600.0, 640.0], "3.0": [1728.0, 576.0], "4.0": [2048.0, 512.0], } ASPECT_RATIO_512_BIN = { "0.25": [256.0, 1024.0], "0.28": [256.0, 928.0], "0.32": [288.0, 896.0], "0.33": [288.0, 864.0], "0.35": [288.0, 832.0], "0.4": [320.0, 800.0], "0.42": [320.0, 768.0], "0.48": [352.0, 736.0], "0.5": [352.0, 704.0], "0.52": [352.0, 672.0], "0.57": [384.0, 672.0], "0.6": [384.0, 640.0], "0.68": [416.0, 608.0], "0.72": [416.0, 576.0], "0.78": [448.0, 576.0], "0.82": [448.0, 544.0], "0.88": [480.0, 544.0], "0.94": [480.0, 512.0], "1.0": [512.0, 512.0], "1.07": [512.0, 480.0], "1.13": [544.0, 480.0], "1.21": [544.0, 448.0], "1.29": [576.0, 448.0], "1.38": [576.0, 416.0], "1.46": [608.0, 416.0], "1.67": [640.0, 384.0], "1.75": [672.0, 384.0], "2.0": [704.0, 352.0], "2.09": [736.0, 352.0], "2.4": [768.0, 320.0], "2.5": [800.0, 320.0], "3.0": [864.0, 288.0], "4.0": [1024.0, 256.0], } ASPECT_RATIO_256_BIN = { "0.25": [128.0, 512.0], "0.28": [128.0, 464.0], "0.32": [144.0, 448.0], "0.33": [144.0, 432.0], "0.35": [144.0, 416.0], "0.4": [160.0, 400.0], "0.42": [160.0, 384.0], "0.48": [176.0, 368.0], "0.5": [176.0, 352.0], "0.52": [176.0, 336.0], "0.57": [192.0, 336.0], "0.6": 
[192.0, 320.0], "0.68": [208.0, 304.0], "0.72": [208.0, 288.0], "0.78": [224.0, 288.0], "0.82": [224.0, 272.0], "0.88": [240.0, 272.0], "0.94": [240.0, 256.0], "1.0": [256.0, 256.0], "1.07": [256.0, 240.0], "1.13": [272.0, 240.0], "1.21": [272.0, 224.0], "1.29": [288.0, 224.0], "1.38": [288.0, 208.0], "1.46": [304.0, 208.0], "1.67": [320.0, 192.0], "1.75": [336.0, 192.0], "2.0": [352.0, 176.0], "2.09": [368.0, 176.0], "2.4": [384.0, 160.0], "2.5": [400.0, 160.0], "3.0": [432.0, 144.0], "4.0": [512.0, 128.0], } # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs, ): """ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. Args: scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. sigmas (`List[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. 
""" if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps class PixArtAlphaPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using PixArt-Alpha. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`T5EncoderModel`]): Frozen text-encoder. 
PixArt-Alpha uses [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. tokenizer (`T5Tokenizer`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). transformer ([`PixArtTransformer2DModel`]): A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. """ bad_punct_regex = re.compile( r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}" ) # noqa _optional_components = ["tokenizer", "text_encoder"] model_cpu_offload_seq = "text_encoder->transformer->vae" def __init__( self, tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, vae: AutoencoderKL, transformer: PixArtTransformer2DModel, scheduler: DPMSolverMultistepScheduler, ): super().__init__() self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], do_classifier_free_guidance: bool = True, negative_prompt: str = "", num_images_per_prompt: int = 1, device: Optional[torch.device] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, clean_caption: bool = False, max_sequence_length: int = 120, **kwargs, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded negative_prompt (`str` or `List[str]`, *optional*): The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For PixArt-Alpha, this should be "". do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt device: (`torch.device`, *optional*): torch device to place the resulting embeddings on prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" string. clean_caption (`bool`, defaults to `False`): If `True`, the function will preprocess and clean the provided caption before encoding. max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. """ if "mask_feature" in kwargs: deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # See Section 3.1. of the paper. 
max_length = max_sequence_length if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because T5 can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) prompt_attention_mask = text_inputs.attention_mask prompt_attention_mask = prompt_attention_mask.to(device) prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.transformer is not None: dtype = self.transformer.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] 
uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) negative_prompt_attention_mask = uncond_input.attention_mask negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) else: negative_prompt_embeds = None negative_prompt_attention_mask = None return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, height, width, negative_prompt, callback_steps, prompt_embeds=None, negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and prompt_attention_mask is None: raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: raise ValueError( "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" f" {negative_prompt_attention_mask.shape}." 
) # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing def _text_preprocessing(self, text, clean_caption=False): if clean_caption and not is_bs4_available(): logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) logger.warning("Setting `clean_caption` to False...") clean_caption = False if clean_caption and not is_ftfy_available(): logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) logger.warning("Setting `clean_caption` to False...") clean_caption = False if not isinstance(text, (tuple, list)): text = [text] def process(text: str): if clean_caption: text = self._clean_caption(text) text = self._clean_caption(text) else: text = text.lower().strip() return text return [process(t) for t in text] # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption def _clean_caption(self, caption): caption = str(caption) caption = ul.unquote_plus(caption) caption = caption.strip().lower() caption = re.sub("", "person", caption) # urls: caption = re.sub( r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa "", caption, ) # regex for urls caption = re.sub( r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa "", caption, ) # regex for urls # html: caption = BeautifulSoup(caption, features="html.parser").text # @ caption = re.sub(r"@[\w\d]+\b", "", caption) # 31C0—31EF CJK Strokes # 31F0—31FF Katakana Phonetic Extensions # 3200—32FF Enclosed CJK Letters and Months # 3300—33FF CJK Compatibility # 3400—4DBF CJK Unified Ideographs Extension A # 4DC0—4DFF Yijing Hexagram Symbols # 4E00—9FFF CJK Unified Ideographs caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) caption = re.sub(r"[\u3200-\u32ff]+", "", caption) caption = re.sub(r"[\u3300-\u33ff]+", "", caption) caption = 
re.sub(r"[\u3400-\u4dbf]+", "", caption) caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) ####################################################### # все виды тире / all types of dash --> "-" caption = re.sub( r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa "-", caption, ) # кавычки к одному стандарту caption = re.sub(r"[`´«»“”¨]", '"', caption) caption = re.sub(r"[‘’]", "'", caption) # " caption = re.sub(r""?", "", caption) # & caption = re.sub(r"&", "", caption) # ip adresses: caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) # article ids: caption = re.sub(r"\d:\d\d\s+$", "", caption) # \n caption = re.sub(r"\\n", " ", caption) # "#123" caption = re.sub(r"#\d{1,3}\b", "", caption) # "#12345.." caption = re.sub(r"#\d{5,}\b", "", caption) # "123456.." caption = re.sub(r"\b\d{6,}\b", "", caption) # filenames: caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) # caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT caption = re.sub(r"\s+\.\s+", r" ", caption) # " . 
" # this-is-my-cute-cat / this_is_my_cute_cat regex2 = re.compile(r"(?:\-|\_)") if len(re.findall(regex2, caption)) > 3: caption = re.sub(regex2, " ", caption) caption = ftfy.fix_text(caption) caption = html.unescape(html.unescape(caption)) caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) caption = re.sub(r"\bpage\s+\d+\b", "", caption) caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) caption = re.sub(r"\b\s+\:\s+", r": ", caption) caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) caption = re.sub(r"\s+", " ", caption) caption.strip() caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) caption = re.sub(r"^\.\S+$", "", caption) return caption.strip() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, negative_prompt: str = "", num_inference_steps: int = 20, timesteps: List[int] = None, sigmas: List[float] = None, guidance_scale: float = 4.5, num_images_per_prompt: Optional[int] = 1, height: Optional[int] = None, width: Optional[int] = None, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, clean_caption: bool = True, use_resolution_binning: bool = True, max_sequence_length: int = 120, **kwargs, ) -> Union[ImagePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 4.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. height (`int`, *optional*, defaults to self.unet.config.sample_size): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size): The width in pixels of the generated image. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for negative text embeddings. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. clean_caption (`bool`, *optional*, defaults to `True`): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. use_resolution_binning (`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the closest resolutions using `ASPECT_RATIO_1024_BIN`. 
After the produced latents are decoded into images, they are resized back to the requested resolution. Useful for generating non-square images. max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images """ if "mask_feature" in kwargs: deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) # 1. Check inputs. Raise error if not correct height = height or self.transformer.config.sample_size * self.vae_scale_factor width = width or self.transformer.config.sample_size * self.vae_scale_factor if use_resolution_binning: if self.transformer.config.sample_size == 128: aspect_ratio_bin = ASPECT_RATIO_1024_BIN elif self.transformer.config.sample_size == 64: aspect_ratio_bin = ASPECT_RATIO_512_BIN elif self.transformer.config.sample_size == 32: aspect_ratio_bin = ASPECT_RATIO_256_BIN else: raise ValueError("Invalid sample size") orig_height, orig_width = height, width height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) self.check_inputs( prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask, ) # 2. 
Default height and width to transformer if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt ( prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, ) = self.encode_prompt( prompt, do_classifier_free_guidance, negative_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, device=device, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, clean_caption=clean_caption, max_sequence_length=max_sequence_length, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 5. Prepare latents. latent_channels = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, latent_channels, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.1 Prepare micro-conditions. 
added_cond_kwargs = {"resolution": None, "aspect_ratio": None} if self.transformer.config.sample_size == 128: resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) if do_classifier_free_guidance: resolution = torch.cat([resolution, resolution], dim=0) aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} # 7. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) current_timestep = t if not torch.is_tensor(current_timestep): # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" if isinstance(current_timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None].to(latent_model_input.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML current_timestep = current_timestep.expand(latent_model_input.shape[0]) # predict noise model_output noise_pred = self.transformer( latent_model_input, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, timestep=current_timestep, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # learned sigma if self.transformer.config.out_channels // 2 == latent_channels: noise_pred = noise_pred.chunk(2, dim=1)[0] else: noise_pred = noise_pred # compute previous image: x_t -> x_t-1 if num_inference_steps == 1: # For DMD one step sampling: https://arxiv.org/abs/2311.18828 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, 
return_dict=False)[0] if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) else: image = latents if not output_type == "latent": image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py ================================================ # Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import html import inspect import re import urllib.parse as ul from typing import Callable, List, Optional, Tuple, Union import torch from transformers import T5EncoderModel, T5Tokenizer from ...image_processor import PixArtImageProcessor from ...models import AutoencoderKL, PixArtTransformer2DModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( BACKENDS_MAPPING, deprecate, is_bs4_available, is_ftfy_available, logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .pipeline_pixart_alpha import ( ASPECT_RATIO_256_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_1024_BIN, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_bs4_available(): from bs4 import BeautifulSoup if is_ftfy_available(): import ftfy ASPECT_RATIO_2048_BIN = { "0.25": [1024.0, 4096.0], "0.26": [1024.0, 3968.0], "0.27": [1024.0, 3840.0], "0.28": [1024.0, 3712.0], "0.32": [1152.0, 3584.0], "0.33": [1152.0, 3456.0], "0.35": [1152.0, 3328.0], "0.4": [1280.0, 3200.0], "0.42": [1280.0, 3072.0], "0.48": [1408.0, 2944.0], "0.5": [1408.0, 2816.0], "0.52": [1408.0, 2688.0], "0.57": [1536.0, 2688.0], "0.6": [1536.0, 2560.0], "0.68": [1664.0, 2432.0], "0.72": [1664.0, 2304.0], "0.78": [1792.0, 2304.0], "0.82": [1792.0, 2176.0], "0.88": [1920.0, 2176.0], "0.94": [1920.0, 2048.0], "1.0": [2048.0, 2048.0], "1.07": [2048.0, 1920.0], "1.13": [2176.0, 1920.0], "1.21": [2176.0, 1792.0], "1.29": [2304.0, 1792.0], "1.38": [2304.0, 1664.0], "1.46": [2432.0, 1664.0], "1.67": [2560.0, 1536.0], "1.75": [2688.0, 1536.0], "2.0": [2816.0, 1408.0], "2.09": [2944.0, 1408.0], "2.4": [3072.0, 1280.0], "2.5": [3200.0, 1280.0], "2.89": [3328.0, 1152.0], "3.0": [3456.0, 1152.0], "3.11": [3584.0, 1152.0], "3.62": [3712.0, 1024.0], "3.75": [3840.0, 1024.0], "3.88": [3968.0, 1024.0], "4.0": [4096.0, 1024.0], } EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` via its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Extra `kwargs` are forwarded untouched to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read the timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Device passed on to `scheduler.set_timesteps`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration, and the number of inference
        steps (the length of the schedule when a custom schedule was supplied, otherwise the value passed in).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    custom_timesteps = timesteps is not None
    custom_sigmas = sigmas is not None

    if custom_timesteps and custom_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own spacing from the step count.
    if not custom_timesteps and not custom_sigmas:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom path: the scheduler must explicitly support the keyword we are about to pass.
    supported = inspect.signature(scheduler.set_timesteps).parameters
    if custom_timesteps:
        if "timesteps" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    do_classifier_free_guidance: bool = True,
    negative_prompt: str = "",
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    clean_caption: bool = False,
    max_sequence_length: int = 300,
    **kwargs,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when `do_classifier_free_guidance` is `True`.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            number of images that should be generated per prompt
        device: (`torch.device`, *optional*):
            torch device to place the resulting embeddings on. Defaults to `self._execution_device`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from the
            `prompt` input argument. When provided, `prompt_attention_mask` must be provided as well.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. When provided, `negative_prompt_attention_mask` must be
            provided as well.
        prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask matching `prompt_embeds`.
        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
            Attention mask matching `negative_prompt_embeds`.
        clean_caption (`bool`, defaults to `False`):
            If `True`, the function will preprocess and clean the provided caption before encoding.
        max_sequence_length (`int`, defaults to 300):
            Maximum sequence length to use for the prompt.

    Returns:
        A 4-tuple `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds,
        negative_prompt_attention_mask)`. Every returned tensor has `batch * num_images_per_prompt` rows, laid
        out so that all copies of one prompt are adjacent and row `i` of a mask belongs to row `i` of the
        corresponding embeddings. The two negative entries are `None` when `do_classifier_free_guidance` is
        `False`.
    """

    if "mask_feature" in kwargs:
        deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
        deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)

    if device is None:
        device = self._execution_device

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # See Section 3.1. of the paper.
    max_length = max_sequence_length

    if prompt_embeds is None:
        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize once more without truncation purely to report what got cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because T5 can only handle sequences up to"
                f" {max_length} tokens: {removed_text}"
            )

        prompt_attention_mask = text_inputs.attention_mask
        prompt_attention_mask = prompt_attention_mask.to(device)

        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
        prompt_embeds = prompt_embeds[0]

    # Pick the dtype of whichever model is loaded; both are optional components.
    if self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    elif self.transformer is not None:
        dtype = self.transformer.dtype
    else:
        dtype = None

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    # The mask must be duplicated with the same per-prompt (interleaved) layout as the embeddings above.
    # Tiling it with `repeat(num_images_per_prompt, 1)` would pair prompt i's embeddings with prompt j's
    # mask whenever more than one prompt and more than one image per prompt are requested.
    prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
    prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
    prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
        uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        negative_prompt_attention_mask = uncond_input.attention_mask
        negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # Same interleaved duplication as for the positive mask, so rows stay aligned with the embeddings.
        negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
        negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
        negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
    else:
        negative_prompt_embeds = None
        negative_prompt_attention_mask = None

    return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt,
    callback_steps,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    prompt_attention_mask=None,
    negative_prompt_attention_mask=None,
):
    """
    Validate the user-facing arguments of the pipeline call.

    Checks are performed in a fixed order and the first violated rule raises `ValueError`:
    image size divisible by 8, `callback_steps` a positive int, exactly one of `prompt` / `prompt_embeds`,
    `prompt` of type `str` or `list`, no mixing of `prompt` or `negative_prompt` with
    `negative_prompt_embeds`, an attention mask accompanying every pre-computed embedding, and matching
    shapes between positive and negative embeddings and masks when both are passed.

    Returns:
        `None` when every check passes.

    Raises:
        ValueError: On the first rule that is violated.
    """
    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    has_negative_embeds = negative_prompt_embeds is not None

    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None` is not an int, so this single test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if has_prompt and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and prompt_attention_mask is None:
        raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

    if has_negative_embeds and negative_prompt_attention_mask is None:
        raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

    if not (has_embeds and has_negative_embeds):
        return

    # Both embeddings were passed directly: they (and their masks) must line up.
    if prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
    if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
        raise ValueError(
            "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
            f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
            f" {negative_prompt_attention_mask.shape}."
        )
re.sub(r"[\u3400-\u4dbf]+", "", caption) caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) ####################################################### # все виды тире / all types of dash --> "-" caption = re.sub( r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa "-", caption, ) # кавычки к одному стандарту caption = re.sub(r"[`´«»“”¨]", '"', caption) caption = re.sub(r"[‘’]", "'", caption) # " caption = re.sub(r""?", "", caption) # & caption = re.sub(r"&", "", caption) # ip adresses: caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) # article ids: caption = re.sub(r"\d:\d\d\s+$", "", caption) # \n caption = re.sub(r"\\n", " ", caption) # "#123" caption = re.sub(r"#\d{1,3}\b", "", caption) # "#12345.." caption = re.sub(r"#\d{5,}\b", "", caption) # "123456.." caption = re.sub(r"\b\d{6,}\b", "", caption) # filenames: caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) # caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT caption = re.sub(r"\s+\.\s+", r" ", caption) # " . 
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Produce the starting latents for the denoising loop.

    When `latents` is `None`, noise of shape
    `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)` is drawn with
    `randn_tensor`; otherwise the supplied tensor is moved to `device`. In both cases the result is multiplied
    by `self.scheduler.init_noise_sigma` before being returned.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(width) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if isinstance(generator, list):
        generator_count = len(generator)
        if generator_count != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

    if latents is not None:
        initial = latents.to(device)
    else:
        initial = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma
timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 4.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. height (`int`, *optional*, defaults to self.unet.config.sample_size): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size): The width in pixels of the generated image. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for negative text embeddings. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. clean_caption (`bool`, *optional*, defaults to `True`): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. use_resolution_binning (`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the closest resolutions using `ASPECT_RATIO_1024_BIN`. 
After the produced latents are decoded into images, they are resized back to the requested resolution. Useful for generating non-square images. max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images """ # 1. Check inputs. Raise error if not correct height = height or self.transformer.config.sample_size * self.vae_scale_factor width = width or self.transformer.config.sample_size * self.vae_scale_factor if use_resolution_binning: if self.transformer.config.sample_size == 256: aspect_ratio_bin = ASPECT_RATIO_2048_BIN elif self.transformer.config.sample_size == 128: aspect_ratio_bin = ASPECT_RATIO_1024_BIN elif self.transformer.config.sample_size == 64: aspect_ratio_bin = ASPECT_RATIO_512_BIN elif self.transformer.config.sample_size == 32: aspect_ratio_bin = ASPECT_RATIO_256_BIN else: raise ValueError("Invalid sample size") orig_height, orig_width = height, width height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) self.check_inputs( prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask, ) # 2. Default height and width to transformer if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. 
Encode input prompt ( prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, ) = self.encode_prompt( prompt, do_classifier_free_guidance, negative_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, device=device, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, clean_caption=clean_caption, max_sequence_length=max_sequence_length, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 5. Prepare latents. latent_channels = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, latent_channels, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.1 Prepare micro-conditions. added_cond_kwargs = {"resolution": None, "aspect_ratio": None} # 7. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) current_timestep = t if not torch.is_tensor(current_timestep): # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" if isinstance(current_timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None].to(latent_model_input.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML current_timestep = current_timestep.expand(latent_model_input.shape[0]) # predict noise model_output noise_pred = self.transformer( latent_model_input, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, timestep=current_timestep, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # learned sigma if self.transformer.config.out_channels // 2 == latent_channels: noise_pred = noise_pred.chunk(2, dim=1)[0] else: noise_pred = noise_pred # compute previous image: x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) else: image = latents if not output_type == "latent": image = 
self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/semantic_stable_diffusion/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_output"] = ["SemanticStableDiffusionPipelineOutput"] _import_structure["pipeline_semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/semantic_stable_diffusion/pipeline_output.py ================================================ from dataclasses import dataclass from typing import List, Optional, Union import numpy as np import PIL.Image from ...utils import BaseOutput @dataclass class 
@dataclass
class SemanticStableDiffusionPipelineOutput(BaseOutput):
    """
    Result container produced by the Semantic Stable Diffusion pipeline.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            The denoised images, either as a list of PIL images of length `batch_size` or as a NumPy array of
            shape `(batch_size, height, width, num_channels)`.
        nsfw_content_detected (`List[bool]`):
            One flag per generated image telling whether it was classified as "not-safe-for-work" (nsfw)
            content, or `None` when safety checking could not be performed.
    """

    # Generated images (PIL list or NumPy batch).
    images: Union[List[PIL.Image.Image], np.ndarray]
    # Per-image nsfw flags, or `None` when no safety check was run.
    nsfw_content_detected: Optional[List[bool]]
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and derive the image-processing helpers.

    Args:
        vae: Autoencoder registered as `self.vae`; its `config.block_out_channels` determines
            `self.vae_scale_factor`.
        text_encoder: Text encoder registered as `self.text_encoder`.
        tokenizer: Tokenizer registered as `self.tokenizer`.
        unet: Denoising network registered as `self.unet`.
        scheduler: Scheduler registered as `self.scheduler`.
        safety_checker: Optional safety checker; may be `None` to disable safety checking.
        feature_extractor: Feature extractor feeding the safety checker; required whenever
            `safety_checker` is not `None`.
        requires_safety_checker: Stored in the pipeline config; when `True` and no safety checker is
            given, a warning is logged.

    Raises:
        ValueError: If a safety checker is supplied without a feature extractor.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first literal must be an f-string, otherwise `{self.__class__}` is emitted verbatim
        # instead of naming the pipeline class.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # Pixel-to-latent ratio: a factor of two for every VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker over `image`.

    Args:
        image: Generated image batch, either a torch tensor or an array accepted by
            `self.image_processor.numpy_to_pil`.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the extracted pixel values are cast to before the check.

    Returns:
        A `(image, has_nsfw_concept)` tuple. When `self.safety_checker` is `None` the image is
        returned untouched and the second element is `None`; otherwise both values come from the
        safety checker.
    """
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images, so convert from whichever form we received.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    clip_pixels = clip_features.pixel_values.to(dtype)
    checked_image, nsfw_flags = self.safety_checker(images=image, clip_input=clip_pixels)
    return checked_image, nsfw_flags
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only forwarded
    when the scheduler's `step` declares a parameter of that name. `eta` corresponds to η in
    the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Args:
        generator: Random generator(s) to forward under the key `"generator"`.
        eta: Value to forward under the key `"eta"`.

    Returns:
        A dict containing `"eta"` and/or `"generator"`, restricted to the names accepted by
        `self.scheduler.step`.
    """
    accepted_names = set(inspect.signature(self.scheduler.step).parameters.keys())
    optional_kwargs = (("eta", eta), ("generator", generator))
    return {key: value for key, value in optional_kwargs if key in accepted_names}
# Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the user-facing call arguments, raising `ValueError` on the first problem found.

    The checks run in this order: image size divisible by 8, `callback_steps` a positive int
    (or `None`), every requested callback tensor name listed in
    `self._callback_tensor_inputs`, exactly one of `prompt` / `prompt_embeds` given with
    `prompt` being a `str` or `list`, at most one of `negative_prompt` /
    `negative_prompt_embeds` given, and matching shapes when both embeddings are passed.

    Raises:
        ValueError: If any of the conditions above is violated.
    """
    # 1. Spatial size.
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # 2. Callback frequency.
    steps_valid = callback_steps is None or (isinstance(callback_steps, int) and callback_steps > 0)
    if not steps_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # 3. Callback tensor names.
    if callback_on_step_end_tensor_inputs is not None:
        if any(k not in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

    # 4. Prompt versus pre-computed prompt embeddings.
    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # 5. Negative prompt versus pre-computed negative embeddings.
    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # 6. Both embeddings supplied: they must agree in shape.
    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the starting latents for the denoising loop, scaled for the scheduler.

    Args:
        batch_size: Effective batch size (first dimension of the latents).
        num_channels_latents: Number of latent channels (second dimension).
        height: Image height in pixels; integer-divided by `self.vae_scale_factor`.
        width: Image width in pixels; integer-divided by `self.vae_scale_factor`.
        dtype: Dtype used when fresh noise is sampled.
        device: Device for the sampled noise, or the device user-supplied latents are moved to.
        generator: A generator or a list of generators; a list must have `batch_size` entries.
        latents: Optional pre-made latents. When given they are only moved to `device`.

    Returns:
        The latents multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(width) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        latents = latents.to(device)
    else:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return latents * self.scheduler.init_noise_sigma
Optional[List[float]] = None, sem_guidance: Optional[List[torch.Tensor]] = None, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide image generation. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor is generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. editing_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to use for semantic guidance. Semantic guidance is disabled by setting `editing_prompt = None`. Guidance direction of prompt should be specified via `reverse_editing_direction`. editing_prompt_embeddings (`torch.Tensor`, *optional*): Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be specified via `reverse_editing_direction`. reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): Whether the corresponding prompt in `editing_prompt` should be increased or decreased. edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): Guidance scale for semantic guidance. If provided as a list, values should correspond to `editing_prompt`. edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10): Number of diffusion steps (for each prompt) for which semantic guidance is not applied. Momentum is calculated for those steps and applied once all warmup periods are over. 
edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`): Number of diffusion steps (for each prompt) after which semantic guidance is longer applied. edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): Threshold of semantic guidance. edit_momentum_scale (`float`, *optional*, defaults to 0.1): Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0, momentum is disabled. Momentum is already built up during warmup (for diffusion steps smaller than `sld_warmup_steps`). Momentum is only added to latent guidance once all warmup periods are finished. edit_mom_beta (`float`, *optional*, defaults to 0.4): Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous momentum is kept. Momentum is already built up during warmup (for diffusion steps smaller than `edit_warmup_steps`). edit_weights (`List[float]`, *optional*, defaults to `None`): Indicates how much each individual concept should influence the overall guidance. If no weights are provided all concepts are applied equally. sem_guidance (`List[torch.Tensor]`, *optional*): List of pre-generated guidance vectors to be applied at generation. Length of the list has to correspond to `num_inference_steps`. Examples: ```py >>> import torch >>> from diffusers import SemanticStableDiffusionPipeline >>> pipe = SemanticStableDiffusionPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> out = pipe( ... prompt="a photo of the face of a woman", ... num_images_per_prompt=1, ... guidance_scale=7, ... editing_prompt=[ ... "smiling, smile", # Concepts to apply ... "glasses, wearing glasses", ... "curls, wavy hair, curly hair", ... "beard, full beard, mustache", ... ], ... reverse_editing_direction=[ ... False, ... False, ... False, ... False, ... ], # Direction of guidance i.e. increase all concepts ... 
edit_warmup_steps=[10, 10, 10, 10], # Warmup period for each concept ... edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept ... edit_threshold=[ ... 0.99, ... 0.975, ... 0.925, ... 0.96, ... ], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions ... edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance ... edit_mom_beta=0.6, # Momentum beta ... edit_weights=[1, 1, 1, 1, 1], # Weights of the individual concepts against each other ... ) >>> image = out.images[0] ``` Returns: [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) # 2. 
Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device if editing_prompt: enable_edit_guidance = True if isinstance(editing_prompt, str): editing_prompt = [editing_prompt] enabled_editing_prompts = len(editing_prompt) elif editing_prompt_embeddings is not None: enable_edit_guidance = True enabled_editing_prompts = editing_prompt_embeddings.shape[0] else: enabled_editing_prompts = 0 enable_edit_guidance = False # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt", ) text_input_ids = text_inputs.input_ids if text_input_ids.shape[-1] > self.tokenizer.model_max_length: removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_embeddings = self.text_encoder(text_input_ids.to(device))[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) if enable_edit_guidance: # get safety text embeddings if editing_prompt_embeddings is None: edit_concepts_input = self.tokenizer( [x for item in editing_prompt for x in repeat(item, batch_size)], padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt", ) edit_concepts_input_ids = edit_concepts_input.input_ids if edit_concepts_input_ids.shape[-1] > self.tokenizer.model_max_length: removed_text = self.tokenizer.batch_decode( edit_concepts_input_ids[:, self.tokenizer.model_max_length :] ) logger.warning( "The following part of your input was 
truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length] edit_concepts = self.text_encoder(edit_concepts_input_ids.to(device))[0] else: edit_concepts = editing_prompt_embeddings.to(device).repeat(batch_size, 1, 1) # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed_edit, seq_len_edit, _ = edit_concepts.shape edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1) edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if enable_edit_guidance: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts]) else: text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) # get the initial random noise unless the user supplied it # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, text_embeddings.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. 
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # Initialize edit_momentum to None edit_momentum = None self.uncond_estimates = None self.text_estimates = None self.edit_estimates = None self.sem_guidance = None for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts) # [b,4, 64, 64] noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] noise_pred_edit_concepts = noise_pred_out[2:] # default text guidance noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond) # noise_guidance = (noise_pred_text - noise_pred_edit_concepts[0]) if self.uncond_estimates is None: self.uncond_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_uncond.shape)) self.uncond_estimates[i] = noise_pred_uncond.detach().cpu() if self.text_estimates is None: self.text_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) self.text_estimates[i] = noise_pred_text.detach().cpu() if self.edit_estimates is None and enable_edit_guidance: self.edit_estimates = torch.zeros( (num_inference_steps + 1, len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape) ) if self.sem_guidance is None: self.sem_guidance = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) if edit_momentum is None: edit_momentum = torch.zeros_like(noise_guidance) if enable_edit_guidance: concept_weights = torch.zeros( (len(noise_pred_edit_concepts), noise_guidance.shape[0]), device=device, dtype=noise_guidance.dtype, ) noise_guidance_edit = 
torch.zeros( (len(noise_pred_edit_concepts), *noise_guidance.shape), device=device, dtype=noise_guidance.dtype, ) # noise_guidance_edit = torch.zeros_like(noise_guidance) warmup_inds = [] for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): self.edit_estimates[i, c] = noise_pred_edit_concept if isinstance(edit_guidance_scale, list): edit_guidance_scale_c = edit_guidance_scale[c] else: edit_guidance_scale_c = edit_guidance_scale if isinstance(edit_threshold, list): edit_threshold_c = edit_threshold[c] else: edit_threshold_c = edit_threshold if isinstance(reverse_editing_direction, list): reverse_editing_direction_c = reverse_editing_direction[c] else: reverse_editing_direction_c = reverse_editing_direction if edit_weights: edit_weight_c = edit_weights[c] else: edit_weight_c = 1.0 if isinstance(edit_warmup_steps, list): edit_warmup_steps_c = edit_warmup_steps[c] else: edit_warmup_steps_c = edit_warmup_steps if isinstance(edit_cooldown_steps, list): edit_cooldown_steps_c = edit_cooldown_steps[c] elif edit_cooldown_steps is None: edit_cooldown_steps_c = i + 1 else: edit_cooldown_steps_c = edit_cooldown_steps if i >= edit_warmup_steps_c: warmup_inds.append(c) if i >= edit_cooldown_steps_c: noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept) continue noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3)) tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3)) tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts) if reverse_editing_direction_c: noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 concept_weights[c, :] = tmp_weights noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c # torch.quantile function expects float32 if noise_guidance_edit_tmp.dtype == torch.float32: tmp = torch.quantile( torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2), 
edit_threshold_c, dim=2, keepdim=False, ) else: tmp = torch.quantile( torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32), edit_threshold_c, dim=2, keepdim=False, ).to(noise_guidance_edit_tmp.dtype) noise_guidance_edit_tmp = torch.where( torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None], noise_guidance_edit_tmp, torch.zeros_like(noise_guidance_edit_tmp), ) noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp warmup_inds = torch.tensor(warmup_inds).to(device) if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: concept_weights = concept_weights.to("cpu") # Offload to cpu noise_guidance_edit = noise_guidance_edit.to("cpu") concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds) concept_weights_tmp = torch.where( concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp ) concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds) noise_guidance_edit_tmp = torch.einsum( "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp ) noise_guidance = noise_guidance + noise_guidance_edit_tmp self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu() del noise_guidance_edit_tmp del concept_weights_tmp concept_weights = concept_weights.to(device) noise_guidance_edit = noise_guidance_edit.to(device) concept_weights = torch.where( concept_weights < 0, torch.zeros_like(concept_weights), concept_weights ) concept_weights = torch.nan_to_num(concept_weights) noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) noise_guidance_edit = noise_guidance_edit.to(edit_momentum.device) noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum edit_momentum = edit_mom_beta * edit_momentum 
+ (1 - edit_mom_beta) * noise_guidance_edit if warmup_inds.shape[0] == len(noise_pred_edit_concepts): noise_guidance = noise_guidance + noise_guidance_edit self.sem_guidance[i] = noise_guidance_edit.detach().cpu() if sem_guidance is not None: edit_guidance = sem_guidance[i].to(device) noise_guidance = noise_guidance + edit_guidance noise_pred = noise_pred_uncond + noise_guidance # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # 8. Post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: src/diffusers/pipelines/shap_e/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["camera"] = ["create_pan_cameras"] _import_structure["pipeline_shap_e"] = ["ShapEPipeline"] _import_structure["pipeline_shap_e_img2img"] = ["ShapEImg2ImgPipeline"] _import_structure["renderer"] = [ "BoundingBoxVolume", "ImportanceRaySampler", "MLPNeRFModelOutput", "MLPNeRSTFModel", "ShapEParamsProjModel", "ShapERenderer", "StratifiedRaySampler", "VoidNeRFModel", ] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .camera import create_pan_cameras from .pipeline_shap_e import ShapEPipeline from .pipeline_shap_e_img2img import ShapEImg2ImgPipeline from .renderer import ( BoundingBoxVolume, ImportanceRaySampler, MLPNeRFModelOutput, MLPNeRSTFModel, ShapEParamsProjModel, ShapERenderer, StratifiedRaySampler, VoidNeRFModel, ) else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/shap_e/camera.py ================================================ # Copyright 2024 Open AI and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@dataclass
class DifferentiableProjectiveCamera:
    """
    Implements a batch, differentiable, standard pinhole camera.

    Every camera in the batch has an `origin` and three axes. Rays are built as `z + x * u + y * v`, where `(u, v)`
    are the pixel offsets scaled by the field of view (see `get_camera_rays`), and are then normalised.
    """

    origin: torch.Tensor  # [batch_size x 3]
    x: torch.Tensor  # [batch_size x 3]
    y: torch.Tensor  # [batch_size x 3]
    z: torch.Tensor  # [batch_size x 3]
    width: int
    height: int
    x_fov: float
    y_fov: float
    # (batch_size, *inner_shape); `camera_rays` requires prod(shape) == origin.shape[0].
    shape: Tuple[int, ...]

    def __post_init__(self):
        # All four tensors must be 2-D, share the batch dimension and have 3 components.
        assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
        assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
        assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2

    def resolution(self):
        """Return `[width, height]` as a float32 CPU tensor."""
        return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))

    def fov(self):
        """Return `[x_fov, y_fov]` as a float32 CPU tensor."""
        return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))

    def get_image_coords(self) -> torch.Tensor:
        """
        :return: coords of shape (width * height, 2)
        """
        pixel_indices = torch.arange(self.height * self.width)
        # Column index first, row index second (row-major pixel enumeration).
        coords = torch.stack(
            [
                pixel_indices % self.width,
                torch.div(pixel_indices, self.width, rounding_mode="trunc"),
            ],
            dim=1,
        )
        return coords

    @property
    def camera_rays(self):
        """
        Rays for every pixel of every camera.

        :return: tensor of shape (batch_size, inner_batch_size * height * width, 2, 3); index 0 of the third axis is
            the ray origin and index 1 the unit direction.
        """
        batch_size, *inner_shape = self.shape
        inner_batch_size = int(np.prod(inner_shape))

        coords = self.get_image_coords()
        coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
        rays = self.get_camera_rays(coords)

        rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3)

        return rays

    def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Convert pixel coordinates into (origin, direction) ray pairs.

        :param coords: tensor of shape (n_cameras, *shape, 2) holding (column, row) pixel positions; `n_cameras` must
            equal `origin.shape[0]`.
        :return: tensor of shape (n_cameras, *shape, 2, 3).
        """
        batch_size, *shape, n_coords = coords.shape
        assert n_coords == 2
        assert batch_size == self.origin.shape[0]

        flat = coords.view(batch_size, -1, 2)

        res = self.resolution()
        fov = self.fov()

        # Map pixel positions to [-1, 1], then scale by tan(fov / 2) to get image-plane offsets.
        fracs = (flat.float() / (res - 1)) * 2 - 1
        fracs = fracs * torch.tan(fov / 2)

        fracs = fracs.view(batch_size, -1, 2)
        directions = (
            self.z.view(batch_size, 1, 3)
            + self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
            + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
        )
        directions = directions / directions.norm(dim=-1, keepdim=True)
        rays = torch.stack(
            [
                torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]),
                directions,
            ],
            dim=2,
        )
        return rays.view(batch_size, *shape, 2, 3)

    def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
        """
        Creates a new camera for the resized view assuming the aspect ratio does not change.

        The pose tensors, fields of view and `shape` are carried over unchanged; only `width` and `height` differ.
        """
        assert width * self.height == height * self.width, "The aspect ratio should not change."
        return DifferentiableProjectiveCamera(
            origin=self.origin,
            x=self.x,
            y=self.y,
            z=self.z,
            width=width,
            height=height,
            x_fov=self.x_fov,
            y_fov=self.y_fov,
            # `shape` is a required dataclass field; omitting it made this method raise TypeError.
            shape=self.shape,
        )


def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera:
    """
    Build a batch of 20 square cameras panning around the vertical axis.

    The angles are `np.linspace(0, 2 * pi, 20)` (both endpoints included). Each camera sits at `-4 * z`, where `z` is
    the normalised vector `(sin(theta), cos(theta), -0.5)`.

    :param size: width and height, in pixels, of every camera image.
    :return: a camera batch with `shape == (1, 20)` and a field of view of 0.7 on both axes.
    """
    origins = []
    xs = []
    ys = []
    zs = []
    for theta in np.linspace(0, 2 * np.pi, num=20):
        z = np.array([np.sin(theta), np.cos(theta), -0.5])
        z /= np.sqrt(np.sum(z**2))
        origin = -z * 4
        x = np.array([np.cos(theta), -np.sin(theta), 0.0])
        y = np.cross(z, x)
        origins.append(origin)
        xs.append(x)
        ys.append(y)
        zs.append(z)
    return DifferentiableProjectiveCamera(
        origin=torch.from_numpy(np.stack(origins, axis=0)).float(),
        x=torch.from_numpy(np.stack(xs, axis=0)).float(),
        y=torch.from_numpy(np.stack(ys, axis=0)).float(),
        z=torch.from_numpy(np.stack(zs, axis=0)).float(),
        width=size,
        height=size,
        x_fov=0.7,
        y_fov=0.7,
        shape=(1, len(xs)),
    )
class ShapEPipeline(DiffusionPipeline):
    """
    Pipeline for generating latent representation of a 3D asset and rendering with the NeRF method.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        text_encoder ([`~transformers.CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer ([`~transformers.CLIPTokenizer`]):
             A `CLIPTokenizer` to tokenize text.
        scheduler ([`HeunDiscreteScheduler`]):
            A scheduler to be used in combination with the `prior` model to generate image embedding.
        shap_e_renderer ([`ShapERenderer`]):
            Shap-E renderer projects the generated latents into parameters of a MLP to create 3D objects with the NeRF
            rendering method.
    """

    model_cpu_offload_seq = "text_encoder->prior"
    _exclude_from_cpu_offload = ["shap_e_renderer"]

    def __init__(
        self,
        prior: PriorTransformer,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        scheduler: HeunDiscreteScheduler,
        shap_e_renderer: ShapERenderer,
    ):
        super().__init__()

        self.register_modules(
            prior=prior,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            shap_e_renderer=shap_e_renderer,
        )

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
    ):
        """
        Encode `prompt` into the projected text embedding fed to the prior.

        The embedding is repeated `num_images_per_prompt` times per prompt, L2-normalised, optionally prefixed with an
        all-zero unconditional batch (classifier-free guidance) and finally multiplied by `sqrt(embedding_dim)`.

        Note: as a side effect this sets `self.tokenizer.pad_token_id` to 0.
        """
        # YiYi Notes: set pad_token_id to be 0, not sure why I can't set in the config file
        self.tokenizer.pad_token_id = 0
        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation only to report what was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        text_encoder_output = self.text_encoder(text_input_ids.to(device))
        prompt_embeds = text_encoder_output.text_embeds

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        # in Shap-E it normalize the prompt_embeds and then later rescale it
        prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)

        if do_classifier_free_guidance:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # Rescale the features to have unit variance
        prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds

        return prompt_embeds

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        guidance_scale: float = 4.0,
        frame_size: int = 64,
        output_type: Optional[str] = "pil",  # pil, np, latent, mesh
        return_dict: bool = True,
    ):
        """
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            frame_size (`int`, *optional*, default to 64):
                The width and height of each image frame of the generated 3D output.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`), `"latent"` (`torch.Tensor`), or mesh ([`MeshDecoderOutput`]).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] instead of a plain
                tuple.

        Examples:

        Returns:
            [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images.

        Raises:
            ValueError: If `prompt` is neither a `str` nor a `list`, or if `output_type` is not one of `"pil"`,
                `"np"`, `"latent"` or `"mesh"`. Both checks run before any sampling is done.
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # Fail fast: reject an unsupported output type before running the (expensive) denoising loop.
        if output_type not in ["np", "pil", "latent", "mesh"]:
            raise ValueError(
                f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
            )

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = guidance_scale > 1.0
        prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)

        # prior

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        num_embeddings = self.prior.config.num_embeddings
        embedding_dim = self.prior.config.embedding_dim

        latents = self.prepare_latents(
            (batch_size, num_embeddings * embedding_dim),
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
        latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)

        for t in self.progress_bar(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.prior(
                scaled_model_input,
                timestep=t,
                proj_embedding=prompt_embeds,
            ).predicted_image_embedding

            # remove the variance
            noise_pred, _ = noise_pred.split(
                scaled_model_input.shape[2], dim=2
            )  # batch_size, num_embeddings, embedding_dim

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

            latents = self.scheduler.step(
                noise_pred,
                timestep=t,
                sample=latents,
            ).prev_sample

        # Offload all models
        self.maybe_free_model_hooks()

        if output_type == "latent":
            # NOTE(review): this branch ignores `return_dict`; kept as-is for backward compatibility.
            return ShapEPipelineOutput(images=latents)

        images = []
        if output_type == "mesh":
            for latent in latents:
                mesh = self.shap_e_renderer.decode_to_mesh(
                    latent[None, :],
                    device,
                )
                images.append(mesh)

        else:
            # np, pil
            for latent in latents:
                image = self.shap_e_renderer.decode_to_image(
                    latent[None, :],
                    device,
                    size=frame_size,
                )
                images.append(image)

            images = torch.stack(images)

            images = images.cpu().numpy()

            if output_type == "pil":
                images = [self.numpy_to_pil(image) for image in images]

        if not return_dict:
            return (images,)

        return ShapEPipelineOutput(images=images)
class ShapEImg2ImgPipeline(DiffusionPipeline):
    """
    Pipeline for generating latent representation of a 3D asset and rendering with the NeRF method from an image.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        image_encoder ([`~transformers.CLIPVisionModel`]):
            Frozen image-encoder.
        image_processor ([`~transformers.CLIPImageProcessor`]):
             A `CLIPImageProcessor` to process images.
        scheduler ([`HeunDiscreteScheduler`]):
            A scheduler to be used in combination with the `prior` model to generate image embedding.
        shap_e_renderer ([`ShapERenderer`]):
            Shap-E renderer projects the generated latents into parameters of a MLP to create 3D objects with the NeRF
            rendering method.
    """

    model_cpu_offload_seq = "image_encoder->prior"
    _exclude_from_cpu_offload = ["shap_e_renderer"]

    def __init__(
        self,
        prior: PriorTransformer,
        image_encoder: CLIPVisionModel,
        image_processor: CLIPImageProcessor,
        scheduler: HeunDiscreteScheduler,
        shap_e_renderer: ShapERenderer,
    ):
        super().__init__()

        self.register_modules(
            prior=prior,
            image_encoder=image_encoder,
            image_processor=image_processor,
            scheduler=scheduler,
            shap_e_renderer=shap_e_renderer,
        )

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def _encode_image(
        self,
        image,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
    ):
        """
        Encode `image` into the per-patch hidden states fed to the prior.

        A list of tensors is concatenated (4-D items) or stacked (otherwise); anything that is not a tensor is run
        through `self.image_processor`. The first token of the encoder's `last_hidden_state` is dropped, the result is
        repeated `num_images_per_prompt` times per image and, for classifier-free guidance, prefixed with an all-zero
        unconditional batch.
        """
        if isinstance(image, list) and isinstance(image[0], torch.Tensor):
            image = torch.cat(image, dim=0) if image[0].ndim == 4 else torch.stack(image, dim=0)

        if not isinstance(image, torch.Tensor):
            # Keep the whole processed batch. Taking only `pixel_values[0]` silently dropped every image after the
            # first when a list of PIL images was passed, while `__call__` still sized the latents for all of them.
            image = self.image_processor(image, return_tensors="pt").pixel_values

        image = image.to(dtype=self.image_encoder.dtype, device=device)

        image_embeds = self.image_encoder(image)["last_hidden_state"]
        image_embeds = image_embeds[:, 1:, :].contiguous()  # batch_size, dim, 256

        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            negative_image_embeds = torch.zeros_like(image_embeds)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeds = torch.cat([negative_image_embeds, image_embeds])

        return image_embeds

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: Union[PIL.Image.Image, List[PIL.Image.Image]],
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        guidance_scale: float = 4.0,
        frame_size: int = 64,
        output_type: Optional[str] = "pil",  # pil, np, latent, mesh
        return_dict: bool = True,
    ):
        """
        The call function to the pipeline for generation.

        Args:
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image` or tensor representing an image batch to be used as the starting point. Can also accept image
                latents as image, but if passing latents directly it is not encoded again.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            frame_size (`int`, *optional*, default to 64):
                The width and height of each image frame of the generated 3D output.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`), `"latent"` (`torch.Tensor`), or mesh ([`MeshDecoderOutput`]).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] instead of a plain
                tuple.

        Examples:

        Returns:
            [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images.

        Raises:
            ValueError: If `image` has an unsupported type, or if `output_type` is not one of `"pil"`, `"np"`,
                `"latent"` or `"mesh"`. Both checks run before any sampling is done.
        """

        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, torch.Tensor):
            batch_size = image.shape[0]
        elif isinstance(image, list) and isinstance(image[0], (torch.Tensor, PIL.Image.Image)):
            batch_size = len(image)
        else:
            raise ValueError(
                f"`image` has to be of type `PIL.Image.Image`, `torch.Tensor`, `List[PIL.Image.Image]` or `List[torch.Tensor]` but is {type(image)}"
            )

        # Fail fast: reject an unsupported output type before running the (expensive) denoising loop.
        if output_type not in ["np", "pil", "latent", "mesh"]:
            raise ValueError(
                f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
            )

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = guidance_scale > 1.0
        image_embeds = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)

        # prior

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        num_embeddings = self.prior.config.num_embeddings
        embedding_dim = self.prior.config.embedding_dim
        # NOTE(review): caller-supplied latents skip `prepare_latents`, so they are neither shape-checked, moved to
        # `device`, nor scaled by `init_noise_sigma` (unlike the text pipeline). Left unchanged because altering it
        # would change results for existing callers - confirm the intended contract.
        if latents is None:
            latents = self.prepare_latents(
                (batch_size, num_embeddings * embedding_dim),
                image_embeds.dtype,
                device,
                generator,
                latents,
                self.scheduler,
            )

        # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
        latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)

        for t in self.progress_bar(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.prior(
                scaled_model_input,
                timestep=t,
                proj_embedding=image_embeds,
            ).predicted_image_embedding

            # remove the variance
            noise_pred, _ = noise_pred.split(
                scaled_model_input.shape[2], dim=2
            )  # batch_size, num_embeddings, embedding_dim

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

            latents = self.scheduler.step(
                noise_pred,
                timestep=t,
                sample=latents,
            ).prev_sample

        # Offload all models
        self.maybe_free_model_hooks()

        if output_type == "latent":
            # NOTE(review): this branch ignores `return_dict`; kept as-is for backward compatibility.
            return ShapEPipelineOutput(images=latents)

        images = []
        if output_type == "mesh":
            for latent in latents:
                mesh = self.shap_e_renderer.decode_to_mesh(
                    latent[None, :],
                    device,
                )
                images.append(mesh)

        else:
            # np, pil
            for latent in latents:
                image = self.shap_e_renderer.decode_to_image(
                    latent[None, :],
                    device,
                    size=frame_size,
                )
                images.append(image)

            images = torch.stack(images)

            images = images.cpu().numpy()

            if output_type == "pil":
                images = [self.numpy_to_pil(image) for image in images]

        if not return_dict:
            return (images,)

        return ShapEPipelineOutput(images=images)
import math from dataclasses import dataclass from typing import Dict, Optional, Tuple import numpy as np import torch import torch.nn.functional as F from torch import nn from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin from ...utils import BaseOutput from .camera import create_pan_cameras def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor: r""" Sample from the given discrete probability distribution with replacement. The i-th bin is assumed to have mass pmf[i]. Args: pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all() n_samples: number of samples Return: indices sampled with replacement """ *shape, support_size, last_dim = pmf.shape assert last_dim == 1 cdf = torch.cumsum(pmf.view(-1, support_size), dim=1) inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device)) return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1) def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor: """ Concatenate x and its positional encodings, following NeRF. Reference: https://arxiv.org/pdf/2210.04628.pdf """ if min_deg == max_deg: return x scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device) *shape, dim = x.shape xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1) assert xb.shape[-1] == dim * (max_deg - min_deg) emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin() return torch.cat([x, emb], dim=-1) def encode_position(position): return posenc_nerf(position, min_deg=0, max_deg=15) def encode_direction(position, direction=None): if direction is None: return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8)) else: return posenc_nerf(direction, min_deg=0, max_deg=8) def _sanitize_name(x: str) -> str: return x.replace(".", "__") def integrate_samples(volume_range, ts, density, channels): r""" Function integrating the model output. 
Args: volume_range: Specifies the integral range [t0, t1] ts: timesteps density: torch.Tensor [batch_size, *shape, n_samples, 1] channels: torch.Tensor [batch_size, *shape, n_samples, n_channels] returns: channels: integrated rgb output weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density *transmittance)[i] weight for each rgb output at [..., i, :]. transmittance: transmittance of this volume ) """ # 1. Calculate the weights _, _, dt = volume_range.partition(ts) ddensity = density * dt mass = torch.cumsum(ddensity, dim=-2) transmittance = torch.exp(-mass[..., -1, :]) alphas = 1.0 - torch.exp(-ddensity) Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2)) # This is the probability of light hitting and reflecting off of # something at depth [..., i, :]. weights = alphas * Ts # 2. Integrate channels channels = torch.sum(channels * weights, dim=-2) return channels, weights, transmittance def volume_query_points(volume, grid_size): indices = torch.arange(grid_size**3, device=volume.bbox_min.device) zs = indices % grid_size ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size combined = torch.stack([xs, ys, zs], dim=1) return (combined.float() / (grid_size - 1)) * (volume.bbox_max - volume.bbox_min) + volume.bbox_min def _convert_srgb_to_linear(u: torch.Tensor): return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4) def _create_flat_edge_indices( flat_cube_indices: torch.Tensor, grid_size: Tuple[int, int, int], ): num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2] y_offset = num_xs num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2] z_offset = num_xs + num_ys return torch.stack( [ # Edges spanning x-axis. 
class VoidNeRFModel(nn.Module):
    """
    Implements the default empty space model where all queries are rendered as background.

    Args:
        background: per-channel background colour (any sequence accepted by `np.array`).
        channel_scale: divisor applied to `background` before it is stored (255.0 maps 8-bit colours to [0, 1]).
    """

    def __init__(self, background, channel_scale=255.0):
        super().__init__()
        # Keep the colour as a plain tensor buffer. Wrapping it in `nn.Parameter` before `register_buffer` made the
        # buffer require grad, which attaches every forward output to the autograd graph and turns the buffer into a
        # non-leaf tensor as soon as the module is cast or moved.
        background = torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale

        self.register_buffer("background", background)

    def forward(self, position):
        """
        Return the background colour for every queried position.

        Args:
            position: tensor of shape [batch_size, *shape, 3]; only its leading shape and device are used.

        Return:
            tensor of shape [batch_size, *shape, n_channels] holding the background colour everywhere.
        """
        background = self.background[None].to(position.device)

        shape = position.shape[:-1]
        # Singleton dims so the [1, n_channels] colour broadcasts over every leading dim of `position`.
        ones = [1] * (len(shape) - 1)
        n_channels = background.shape[-1]
        background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels])

        return background
""" super().__init__() self.min_dist = min_dist self.min_t_range = min_t_range self.bbox_min = torch.tensor(bbox_min) self.bbox_max = torch.tensor(bbox_max) self.bbox = torch.stack([self.bbox_min, self.bbox_max]) assert self.bbox.shape == (2, 3) assert min_dist >= 0.0 assert min_t_range > 0.0 def intersect( self, origin: torch.Tensor, direction: torch.Tensor, t0_lower: Optional[torch.Tensor] = None, epsilon=1e-6, ): """ Args: origin: [batch_size, *shape, 3] direction: [batch_size, *shape, 3] t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume. params: Optional meta parameters in case Volume is parametric epsilon: to stabilize calculations Return: A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to be on the boundary of the volume. """ batch_size, *shape, _ = origin.shape ones = [1] * len(shape) bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device) def _safe_divide(a, b, epsilon=1e-6): return a / torch.where(b < 0, b - epsilon, b + epsilon) ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon) # Cases to think about: # # 1. t1 <= t0: the ray does not pass through the AABB. # 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin. # 3. t0 <= 0 <= t1: the ray starts from inside the BB # 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice. # # 1 and 4 are clearly handled from t0 < t1 below. # Making t0 at least min_dist (>= 0) takes care of 2 and 3. 
class StratifiedRaySampler(nn.Module):
    """
    Instead of fixed intervals, a sample is drawn uniformly at random from each interval.
    """

    def __init__(self, depth_mode: str = "linear"):
        """
        :param depth_mode: linear samples ts linearly in depth. harmonic ensures
            closer points are sampled more densely.
        """
        # nn.Module must be initialised before use; without this call `parameters()`, `repr()`, `.to()` etc.
        # raise AttributeError because the internal registries (`_parameters`, `_modules`, ...) do not exist.
        super().__init__()
        self.depth_mode = depth_mode
        assert self.depth_mode in ("linear", "geometric", "harmonic")

    def sample(
        self,
        t0: torch.Tensor,
        t1: torch.Tensor,
        n_samples: int,
        epsilon: float = 1e-3,
    ) -> torch.Tensor:
        """
        Args:
            t0: start time has shape [batch_size, *shape, 1]
            t1: finish time has shape [batch_size, *shape, 1]
            n_samples: number of ts to sample
            epsilon: lower clamp applied to t0/t1 in the "geometric" and "harmonic" modes to avoid log(0) and
                division by zero
        Return:
            sampled ts of shape [batch_size, *shape, n_samples, 1]
        """
        ones = [1] * (len(t0.shape) - 1)
        ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)

        if self.depth_mode == "linear":
            ts = t0 * (1.0 - ts) + t1 * ts
        elif self.depth_mode == "geometric":
            ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
        elif self.depth_mode == "harmonic":
            # The original NeRF recommends this interpolation scheme for
            # spherical scenes, but there could be some weird edge cases when
            # the observer crosses from the inner to outer volume.
            ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)

        # Each sample owns the interval between the midpoints to its neighbours (closed off by t0 and t1).
        mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
        upper = torch.cat([mids, t1], dim=-1)
        lower = torch.cat([t0, mids], dim=-1)
        # NOTE(review): this reseeds the *global* torch RNG on every call (left over from testing). It is kept
        # because rendered outputs currently depend on it; removing it, or switching to a local
        # `torch.Generator`, should be done together with updating the expected outputs.
        torch.manual_seed(0)
        t_rand = torch.rand_like(ts)

        # Jitter every sample uniformly inside its own interval.
        ts = lower + (upper - lower) * t_rand
        return ts.unsqueeze(-1)
class MeshDecoder(nn.Module):
    """
    Construct meshes from Signed distance functions (SDFs) using marching cubes method

    The module owns two lookup-table buffers indexed by the 8-bit corner configuration of a cube:
    `cases` ([256, 5, 3], long) gives up to five triangles as triples of local edge indices, and `masks`
    ([256, 5], bool) marks which of those five triangle slots are used. Both are created as zeros here;
    presumably the real tables are loaded from a checkpoint - TODO confirm.
    """

    def __init__(self):
        super().__init__()
        cases = torch.zeros(256, 5, 3, dtype=torch.long)
        masks = torch.zeros(256, 5, dtype=torch.bool)

        self.register_buffer("cases", cases)
        self.register_buffer("masks", masks)

    def forward(self, field: torch.Tensor, min_point: torch.Tensor, size: torch.Tensor) -> "MeshDecoderOutput":
        """
        For a signed distance field, produce a mesh using marching cubes.

        :param field: a 3D tensor of field values, where negative values
                    correspond to the outside of the shape. The dimensions correspond to the x, y, and z directions,
                    respectively.
        :param min_point: a tensor of shape [3] containing the point corresponding
                        to (0, 0, 0) in the field.
        :param size: a tensor of shape [3] containing the per-axis distance from the
                    (0, 0, 0) field corner and the (-1, -1, -1) field corner.
        :return: a `MeshDecoderOutput` with interpolated `verts`, `faces` indexing into `verts`, and
                    `vertex_channels` set to None.
        """
        assert len(field.shape) == 3, "input must be a 3D scalar field"
        dev = field.device

        # Move the lookup tables and the box description onto the device of the field.
        cases = self.cases.to(dev)
        masks = self.masks.to(dev)

        min_point = min_point.to(dev)
        size = size.to(dev)

        grid_size = field.shape
        grid_size_tensor = torch.tensor(grid_size).to(size)

        # Create bitmasks between 0 and 255 (inclusive) indicating the state
        # of the eight corners of each cube. Each step pairs a slab with its +1 neighbour along one axis
        # (x -> shift 1, y -> shift 2, z -> shift 4) and shrinks that axis by one.
        bitmasks = (field > 0).to(torch.uint8)
        bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1)
        bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2)
        bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4)

        # Compute corner coordinates across the entire grid.
        # corner_coords[x, y, z] == (x, y, z) in grid-index units, stored in the dtype of the field.
        corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype)
        corner_coords[range(grid_size[0]), :, :, 0] = torch.arange(grid_size[0], device=dev, dtype=field.dtype)[
            :, None, None
        ]
        corner_coords[:, range(grid_size[1]), :, 1] = torch.arange(grid_size[1], device=dev, dtype=field.dtype)[
            :, None
        ]
        corner_coords[:, :, range(grid_size[2]), 2] = torch.arange(grid_size[2], device=dev, dtype=field.dtype)

        # Compute all vertices across all edges in the grid, even though we will
        # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices.
        # These are all midpoints, and don't account for interpolation (which is
        # done later based on the used edge midpoints).
        # The concatenation order (x-edges, y-edges, z-edges) is what the flat edge indices below index into.
        edge_midpoints = torch.cat(
            [
                ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3),
                ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3),
                ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3),
            ],
            dim=0,
        )

        # Create a flat array of [X, Y, Z] indices for each cube.
        cube_indices = torch.zeros(
            grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long
        )
        cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[:, None, None]
        cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[:, None]
        cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev)
        flat_cube_indices = cube_indices.reshape(-1, 3)

        # Create a flat array mapping each cube to 12 global edge indices.
        edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size)

        # Apply the LUT to figure out the triangles.
        flat_bitmasks = bitmasks.reshape(-1).long()  # must cast to long for indexing to believe this not a mask
        local_tris = cases[flat_bitmasks]
        local_masks = masks[flat_bitmasks]

        # Compute the global edge indices for the triangles.
        # local_tris holds per-cube edge slots (0..11); gather translates them through that cube's edge table.
        global_tris = torch.gather(edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)).reshape(
            local_tris.shape
        )

        # Select the used triangles for each cube.
        selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)]

        # Now we have a bunch of indices into the full list of possible vertices,
        # but we want to reduce this list to only the used vertices.
        used_vertex_indices = torch.unique(selected_tris.view(-1))
        used_edge_midpoints = edge_midpoints[used_vertex_indices]
        old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long)
        old_index_to_new_index[used_vertex_indices] = torch.arange(
            len(used_vertex_indices), device=dev, dtype=torch.long
        )

        # Rewrite the triangles to use the new indices
        faces = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(selected_tris.shape)

        # Compute the actual interpolated coordinates corresponding to edge midpoints.
        # A midpoint has exactly one half-integer coordinate, so floor/ceil recover the two grid corners
        # at the ends of its edge.
        v1 = torch.floor(used_edge_midpoints).to(torch.long)
        v2 = torch.ceil(used_edge_midpoints).to(torch.long)
        s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]]
        s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]]
        # Convert the corner grid indices into coordinates inside the box [min_point, min_point + size].
        p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point
        p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point
        # The signs of s1 and s2 should be different. We want to find
        # t such that t*s2 + (1-t)*s1 = 0.
        t = (s1 / (s1 - s2))[:, None]
        verts = t * p2 + (1 - t) * p1

        return MeshDecoderOutput(verts=verts, faces=faces, vertex_channels=None)
class ChannelsProj(nn.Module):
    """
    Per-vector linear projection from `d_latent` to `channels`, followed by LayerNorm and a bias.

    The weights of all `vectors` projections live in a single `nn.Linear(d_latent, vectors * channels)`; its
    weight is reinterpreted as [vectors, channels, d_latent] and its bias is added after the normalisation.
    """

    def __init__(
        self,
        *,
        vectors: int,
        channels: int,
        d_latent: int,
    ):
        super().__init__()
        self.proj = nn.Linear(d_latent, vectors * channels)
        self.norm = nn.LayerNorm(channels)
        self.d_latent, self.vectors, self.channels = d_latent, vectors, channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map `x` of shape [batch, vectors, d_latent] to [batch, vectors, channels]."""
        weight = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
        bias = self.proj.bias.view(1, self.vectors, self.channels)

        # Every vector v is projected with its own [channels, d_latent] matrix.
        projected = torch.einsum("bvd,vcd->bvc", x, weight)
        return self.norm(projected) + bias
class ShapERenderer(ModelMixin, ConfigMixin):
    """
    Decode a Shap-E latent into images or a textured mesh.

    The latent is projected by `params_proj` into weights for (some of) the layers of the NeRSTF MLP. The MLP is
    then either volume-rendered along camera rays (`decode_to_image`) or queried on a regular grid and meshed with
    marching cubes (`decode_to_mesh`).
    """

    @register_to_config
    def __init__(
        self,
        *,
        param_names: Tuple[str] = (
            "nerstf.mlp.0.weight",
            "nerstf.mlp.1.weight",
            "nerstf.mlp.2.weight",
            "nerstf.mlp.3.weight",
        ),
        param_shapes: Tuple[Tuple[int]] = (
            (256, 93),
            (256, 256),
            (256, 256),
            (256, 256),
        ),
        d_latent: int = 1024,
        d_hidden: int = 256,
        n_output: int = 12,
        n_hidden_layers: int = 6,
        act_fn: str = "swish",
        insert_direction_at: int = 4,
        background: Tuple[float] = (
            255.0,
            255.0,
            255.0,
        ),
    ):
        super().__init__()

        self.params_proj = ShapEParamsProjModel(
            param_names=param_names,
            param_shapes=param_shapes,
            d_latent=d_latent,
        )
        self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
        self.void = VoidNeRFModel(background=background, channel_scale=255.0)
        self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])
        self.mesh_decoder = MeshDecoder()

    @torch.no_grad()
    def _update_mlp_from_latents(self, latents):
        """Project `latents` to MLP weights and copy them in place into the matching `self.mlp` parameters."""
        projected_params = self.params_proj(latents)
        for name, param in self.mlp.state_dict().items():
            key = f"nerstf.{name}"
            if key in projected_params:
                # squeeze(0) drops the batch dim; callers pass a single latent at a time.
                param.copy_(projected_params[key].squeeze(0))

    @torch.no_grad()
    def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False):
        """
        Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written
        below with some abuse of notations)

            C(r) := sum(
                transmittance(t[i]) * integrate(
                    lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]],
                ) for i in range(len(parts))
            ) + transmittance(t[-1]) * void_model(t[-1]).channels

        where

        1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing
           through the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely)
        2) density and channels are obtained by evaluating the appropriate part.model at time t.
        3) [t[i], t[i + 1]] is defined as the range of t where the ray intersects the i-th volume minus the union
           of all earlier volumes, at the surface of the shell (if bounded). If the ray does not intersect, the
           integral over this segment is evaluated as 0 and transmittance(t[i + 1]) := transmittance(t[i]).
        4) The last term is integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model
           (i.e. we consider this space to be empty).

        Args:
            rays: [batch_size x ... x 2 x 3] origin and direction.
            sampler: object whose `sample(t0, t1, n_samples)` returns the ts to integrate at.
            n_samples: number of ts to sample.
            prev_model_out: model output of the previous (coarse) rendering step; when given, its ts are merged
                with the new samples and the "fine" heads of the MLP are used.
            render_with_direction: whether to pass the ray directions to the MLP.

        :return: A tuple of
            - `channels`
            - An importance sampler for additional fine-grained rendering
            - raw model output
        """
        origin, direction = rays[..., 0, :], rays[..., 1, :]

        # Integrate over [t[i], t[i + 1]]

        # 1. Intersect the rays with the current volume and sample ts to integrate along.
        vrange = self.volume.intersect(origin, direction, t0_lower=None)
        ts = sampler.sample(vrange.t0, vrange.t1, n_samples)
        ts = ts.to(rays.dtype)

        if prev_model_out is not None:
            # Append the previous ts now before fprop because previous
            # rendering used a different model and we can't reuse the output.
            ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values

        batch_size = vrange.t0.shape[0]
        ts_shape = ts.shape[1:-1]

        # 2. Get the points along the ray and query the model
        directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
        positions = origin.unsqueeze(-2) + ts * directions

        directions = directions.to(self.mlp.dtype)
        positions = positions.to(self.mlp.dtype)

        optional_directions = directions if render_with_direction else None

        model_out = self.mlp(
            position=positions,
            direction=optional_directions,
            ts=ts,
            nerf_level="coarse" if prev_model_out is None else "fine",
        )

        # 3. Integrate the model results
        channels, weights, transmittance = integrate_samples(
            vrange, model_out.ts, model_out.density, model_out.channels
        )

        # 4. Clean up results that do not intersect with the volume.
        transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance))
        channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels))
        # 5. integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model (i.e. we consider
        # this space to be empty).
        channels = channels + transmittance * self.void(origin)

        weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights)

        return channels, weighted_sampler, model_out

    @torch.no_grad()
    def decode_to_image(
        self,
        latents,
        device,
        size: int = 64,
        ray_batch_size: int = 4096,
        n_coarse_samples=64,
        n_fine_samples=128,
    ):
        """
        Render the asset described by `latents` from the pan cameras created by `create_pan_cameras(size)`.

        Args:
            latents: latent of a single asset, as accepted by `self.params_proj`.
            device: device the camera rays are moved to.
            size: forwarded to `create_pan_cameras`.
            ray_batch_size: maximum number of rays rendered per forward pass.
            n_coarse_samples: stratified samples per ray in the coarse pass.
            n_fine_samples: additional importance samples per ray in the fine pass.

        Return:
            tensor of rendered channels reshaped to [*camera.shape, height, width, n_channels] with a leading
            singleton dimension squeezed away.
        """
        # project the parameters from the generated latents and update the mlp layers of the renderer
        self._update_mlp_from_latents(latents)

        # create cameras object
        camera = create_pan_cameras(size)
        rays = camera.camera_rays
        rays = rays.to(device)

        coarse_sampler = StratifiedRaySampler()

        images = []

        # Step through *all* rays. The previous `rays.shape[1] // ray_batch_size` batch count dropped the trailing
        # partial batch, so the final reshape failed whenever the ray count was not a multiple of the batch size.
        for start in range(0, rays.shape[1], ray_batch_size):
            rays_batch = rays[:, start : start + ray_batch_size]

            # render rays with coarse, stratified samples.
            _, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples)
            # Then, render with additional importance-weighted ray samples.
            channels, _, _ = self.render_rays(
                rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out
            )

            images.append(channels)

        images = torch.cat(images, dim=1)
        images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)

        return images

    @torch.no_grad()
    def decode_to_mesh(
        self,
        latents,
        device,
        grid_size: int = 128,
        query_batch_size: int = 4096,
        texture_channels: Tuple = ("R", "G", "B"),
    ):
        """
        Extract a textured triangle mesh for the asset described by `latents`.

        Args:
            latents: latent of a single asset, as accepted by `self.params_proj`.
            device: device on which the MLP is queried.
            grid_size: number of SDF samples per axis inside the bounding box.
            query_batch_size: maximum number of points sent to the MLP per forward pass.
            texture_channels: names used as keys of the per-vertex colour channels.

        Return:
            the mesh produced by `self.mesh_decoder`, with `vertex_channels` mapping each name in
            `texture_channels` to a per-vertex tensor.
        """
        # 1. + 2. project the parameters from the generated latents and update the mlp layers of the renderer
        self._update_mlp_from_latents(latents)

        # 3. decoding with STF rendering
        # 3.1 query the SDF values at vertices along a regular grid_size**3 grid
        query_points = volume_query_points(self.volume, grid_size)
        query_positions = query_points[None].repeat(1, 1, 1).to(device=device, dtype=self.mlp.dtype)

        fields = []

        for idx in range(0, query_positions.shape[1], query_batch_size):
            query_batch = query_positions[:, idx : idx + query_batch_size]

            model_out = self.mlp(
                position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf"
            )
            fields.append(model_out.signed_distance)

        # predicted SDF values
        fields = torch.cat(fields, dim=1)
        fields = fields.float()

        assert (
            len(fields.shape) == 3 and fields.shape[-1] == 1
        ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"

        fields = fields.reshape(1, *([grid_size] * 3))

        # create grid (grid_size + 2)**3
        # - force a negative border around the SDFs to close off all the models.
        full_grid = torch.zeros(
            1,
            grid_size + 2,
            grid_size + 2,
            grid_size + 2,
            device=fields.device,
            dtype=fields.dtype,
        )
        full_grid.fill_(-1.0)
        full_grid[:, 1:-1, 1:-1, 1:-1] = fields
        fields = full_grid

        # apply a differentiable implementation of Marching Cubes to construct meshes
        raw_meshes = [
            self.mesh_decoder(field, self.volume.bbox_min, self.volume.bbox_max - self.volume.bbox_min)
            for field in fields
        ]

        max_vertices = max(len(m.verts) for m in raw_meshes)

        # 3.2. query the texture color head at each vertex of the resulting mesh.
        # Vertices are cycled so every mesh contributes exactly `max_vertices` query points.
        texture_query_positions = torch.stack(
            [m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes],
            dim=0,
        )
        texture_query_positions = texture_query_positions.to(device=device, dtype=self.mlp.dtype)

        textures = []

        for idx in range(0, texture_query_positions.shape[1], query_batch_size):
            query_batch = texture_query_positions[:, idx : idx + query_batch_size]

            texture_model_out = self.mlp(
                position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf"
            )
            textures.append(texture_model_out.channels)

        # predict texture color
        textures = torch.cat(textures, dim=1)

        textures = _convert_srgb_to_linear(textures)
        textures = textures.float()

        # 3.3 augment the mesh with texture data
        assert len(textures.shape) == 3 and textures.shape[-1] == len(
            texture_channels
        ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"

        for m, texture in zip(raw_meshes, textures):
            # Drop the cycled padding so each mesh keeps one colour per real vertex.
            texture = texture[: len(m.verts)]
            m.vertex_channels = dict(zip(texture_channels, texture.unbind(-1)))

        return raw_meshes[0]
DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .modeling_stable_audio import StableAudioProjectionModel from .pipeline_stable_audio import StableAudioPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/stable_audio/modeling_stable_audio.py ================================================ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from dataclasses import dataclass from math import pi from typing import Optional import torch import torch.nn as nn import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, register_to_config from ...models.modeling_utils import ModelMixin from ...utils import BaseOutput, logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name class StableAudioPositionalEmbedding(nn.Module): """Used for continuous time""" def __init__(self, dim: int): super().__init__() assert (dim % 2) == 0 half_dim = dim // 2 self.weights = nn.Parameter(torch.randn(half_dim)) def forward(self, times: torch.Tensor) -> torch.Tensor: times = times[..., None] freqs = times * self.weights[None] * 2 * pi fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) fouriered = torch.cat((times, fouriered), dim=-1) return fouriered @dataclass class StableAudioProjectionModelOutput(BaseOutput): """ Args: Class for StableAudio projection layer's outputs. text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states obtained by linearly projecting the hidden-states for the text encoder. seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): Sequence of hidden-states obtained by linearly projecting the audio start hidden states. seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): Sequence of hidden-states obtained by linearly projecting the audio end hidden states. """ text_hidden_states: Optional[torch.Tensor] = None seconds_start_hidden_states: Optional[torch.Tensor] = None seconds_end_hidden_states: Optional[torch.Tensor] = None class StableAudioNumberConditioner(nn.Module): """ A simple linear projection model to map numbers to a latent space. Args: number_embedding_dim (`int`): Dimensionality of the number embeddings. min_value (`int`): The minimum value of the seconds number conditioning modules. 
max_value (`int`): The maximum value of the seconds number conditioning modules internal_dim (`int`): Dimensionality of the intermediate number hidden states. """ def __init__( self, number_embedding_dim, min_value, max_value, internal_dim: Optional[int] = 256, ): super().__init__() self.time_positional_embedding = nn.Sequential( StableAudioPositionalEmbedding(internal_dim), nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim), ) self.number_embedding_dim = number_embedding_dim self.min_value = min_value self.max_value = max_value def forward( self, floats: torch.Tensor, ): floats = floats.clamp(self.min_value, self.max_value) normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value) # Cast floats to same type as embedder embedder_dtype = next(self.time_positional_embedding.parameters()).dtype normalized_floats = normalized_floats.to(embedder_dtype) embedding = self.time_positional_embedding(normalized_floats) float_embeds = embedding.view(-1, 1, self.number_embedding_dim) return float_embeds class StableAudioProjectionModel(ModelMixin, ConfigMixin): """ A simple linear projection model to map the conditioning values to a shared latent space. Args: text_encoder_dim (`int`): Dimensionality of the text embeddings from the text encoder (T5). conditioning_dim (`int`): Dimensionality of the output conditioning tensors. min_value (`int`): The minimum value of the seconds number conditioning modules. 
max_value (`int`): The maximum value of the seconds number conditioning modules """ @register_to_config def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value): super().__init__() self.text_projection = ( nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim) ) self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) def forward( self, text_hidden_states: Optional[torch.Tensor] = None, start_seconds: Optional[torch.Tensor] = None, end_seconds: Optional[torch.Tensor] = None, ): text_hidden_states = ( text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states) ) seconds_start_hidden_states = ( start_seconds if start_seconds is None else self.start_number_conditioner(start_seconds) ) seconds_end_hidden_states = end_seconds if end_seconds is None else self.end_number_conditioner(end_seconds) return StableAudioProjectionModelOutput( text_hidden_states=text_hidden_states, seconds_start_hidden_states=seconds_start_hidden_states, seconds_end_hidden_states=seconds_end_hidden_states, ) ================================================ FILE: src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py ================================================ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Callable, List, Optional, Union import torch from transformers import ( T5EncoderModel, T5Tokenizer, T5TokenizerFast, ) from ...models import AutoencoderOobleck, StableAudioDiTModel from ...models.embeddings import get_1d_rotary_pos_embed from ...schedulers import EDMDPMSolverMultistepScheduler from ...utils import ( logging, replace_example_docstring, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline from .modeling_stable_audio import StableAudioProjectionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import scipy >>> import torch >>> import soundfile as sf >>> from diffusers import StableAudioPipeline >>> repo_id = "stabilityai/stable-audio-open-1.0" >>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") >>> # define the prompts >>> prompt = "The sound of a hammer hitting a wooden surface." >>> negative_prompt = "Low quality." >>> # set the seed for generator >>> generator = torch.Generator("cuda").manual_seed(0) >>> # run the generation >>> audio = pipe( ... prompt, ... negative_prompt=negative_prompt, ... num_inference_steps=200, ... audio_end_in_s=10.0, ... num_waveforms_per_prompt=3, ... generator=generator, ... ).audios >>> output = audio[0].T.float().cpu().numpy() >>> sf.write("hammer.wav", output, pipe.vae.sampling_rate) ``` """ class StableAudioPipeline(DiffusionPipeline): r""" Pipeline for text-to-audio generation using StableAudio. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
Args: vae ([`AutoencoderOobleck`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.T5EncoderModel`]): Frozen text-encoder. StableAudio uses the encoder of [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the [google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant. projection_model ([`StableAudioProjectionModel`]): A trained model used to linearly project the hidden-states from the text encoder model and the start and end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to give the input to the transformer model. tokenizer ([`~transformers.T5Tokenizer`]): Tokenizer to tokenize text for the frozen text-encoder. transformer ([`StableAudioDiTModel`]): A `StableAudioDiTModel` to denoise the encoded audio latents. scheduler ([`EDMDPMSolverMultistepScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded audio latents. """ model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae" def __init__( self, vae: AutoencoderOobleck, text_encoder: T5EncoderModel, projection_model: StableAudioProjectionModel, tokenizer: Union[T5Tokenizer, T5TokenizerFast], transformer: StableAudioDiTModel, scheduler: EDMDPMSolverMultistepScheduler, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, projection_model=projection_model, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler, ) self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2 # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 
""" self.vae.enable_slicing() # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() def encode_prompt( self, prompt, device, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None, negative_attention_mask: Optional[torch.LongTensor] = None, ): if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # 1. Tokenize text text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( f"The following part of your input was truncated because {self.text_encoder.config.model_type} can " f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids.to(device) attention_mask = attention_mask.to(device) # 2. 
Text encoder forward self.text_encoder.eval() prompt_embeds = self.text_encoder( text_input_ids, attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if do_classifier_free_guidance and negative_prompt is not None: uncond_tokens: List[str] if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # 1. Tokenize text uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_input_ids = uncond_input.input_ids.to(device) negative_attention_mask = uncond_input.attention_mask.to(device) # 2. Text encoder forward self.text_encoder.eval() negative_prompt_embeds = self.text_encoder( uncond_input_ids, attention_mask=negative_attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if negative_attention_mask is not None: # set the masked tokens to the null embed negative_prompt_embeds = torch.where( negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0 ) # 3. Project prompt_embeds and negative_prompt_embeds if do_classifier_free_guidance and negative_prompt_embeds is not None: # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the negative and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if attention_mask is not None and negative_attention_mask is None: negative_attention_mask = torch.ones_like(attention_mask) elif attention_mask is None and negative_attention_mask is not None: attention_mask = torch.ones_like(negative_attention_mask) if attention_mask is not None: attention_mask = torch.cat([negative_attention_mask, attention_mask]) prompt_embeds = self.projection_model( text_hidden_states=prompt_embeds, ).text_hidden_states if attention_mask is not None: prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) return prompt_embeds def encode_duration( self, audio_start_in_s, audio_end_in_s, device, do_classifier_free_guidance, batch_size, ): audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s] audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s] if len(audio_start_in_s) == 1: audio_start_in_s = audio_start_in_s * batch_size if len(audio_end_in_s) == 1: audio_end_in_s = audio_end_in_s * batch_size # Cast the inputs to floats audio_start_in_s = [float(x) for x in audio_start_in_s] audio_start_in_s = torch.tensor(audio_start_in_s).to(device) audio_end_in_s = [float(x) for x in audio_end_in_s] audio_end_in_s = torch.tensor(audio_end_in_s).to(device) projection_output = self.projection_model( start_seconds=audio_start_in_s, end_seconds=audio_end_in_s, ) seconds_start_hidden_states = projection_output.seconds_start_hidden_states seconds_end_hidden_states = projection_output.seconds_end_hidden_states # For classifier free guidance, we need to do two forward passes. 
# Here we repeat the audio hidden states to avoid doing two forward passes if do_classifier_free_guidance: seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0) seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0) return seconds_start_hidden_states, seconds_end_hidden_states # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, audio_start_in_s, audio_end_in_s, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, attention_mask=None, negative_attention_mask=None, initial_audio_waveforms=None, initial_audio_sampling_rate=None, ): if audio_end_in_s < audio_start_in_s: raise ValueError( f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but " ) if ( audio_start_in_s < self.projection_model.config.min_value or audio_start_in_s > self.projection_model.config.max_value ): raise ValueError( f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " f"is {audio_start_in_s}." 
) if ( audio_end_in_s < self.projection_model.config.min_value or audio_end_in_s > self.projection_model.config.max_value ): raise ValueError( f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " f"is {audio_end_in_s}." ) if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and (prompt_embeds is None): raise ValueError( "Provide either `prompt`, or `prompt_embeds`. Cannot leave" "`prompt` undefined without specifying `prompt_embeds`." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." 
) if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]: raise ValueError( "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:" f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}" ) if initial_audio_sampling_rate is None and initial_audio_waveforms is not None: raise ValueError( "`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`." ) if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate: raise ValueError( f"`initial_audio_sampling_rate` must be {self.vae.hop_length}' but is `{initial_audio_sampling_rate}`." "Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. " ) def prepare_latents( self, batch_size, num_channels_vae, sample_size, dtype, device, generator, latents=None, initial_audio_waveforms=None, num_waveforms_per_prompt=None, audio_channels=None, ): shape = (batch_size, num_channels_vae, sample_size) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma # encode the initial audio for use by the model if initial_audio_waveforms is not None: # check dimension if initial_audio_waveforms.ndim == 2: initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1) elif initial_audio_waveforms.ndim != 3: raise ValueError( f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions" ) audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length) # check num_channels if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2: initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1) elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1: initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True) if initial_audio_waveforms.shape[:2] != audio_shape[:2]: raise ValueError( f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`" ) # crop or pad audio_length = initial_audio_waveforms.shape[-1] if audio_length < audio_vae_length: logger.warning( f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded." ) elif audio_length > audio_vae_length: logger.warning( f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped." 
) audio = initial_audio_waveforms.new_zeros(audio_shape) audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length] encoded_audio = self.vae.encode(audio).latent_dist.sample(generator) encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1)) latents = encoded_audio + latents return latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, audio_end_in_s: Optional[float] = None, audio_start_in_s: Optional[float] = 0.0, num_inference_steps: int = 100, guidance_scale: float = 7.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_waveforms_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, initial_audio_waveforms: Optional[torch.Tensor] = None, initial_audio_sampling_rate: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None, negative_attention_mask: Optional[torch.LongTensor] = None, return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: Optional[int] = 1, output_type: Optional[str] = "pt", ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. audio_end_in_s (`float`, *optional*, defaults to 47.55): Audio end index in seconds. audio_start_in_s (`float`, *optional*, defaults to 0): Audio start index in seconds. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality audio at the expense of slower inference. 
guidance_scale (`float`, *optional*, defaults to 7.0): A higher guidance scale value encourages the model to generate audio that is closely linked to the text `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in audio generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_waveforms_per_prompt (`int`, *optional*, defaults to 1): The number of waveforms to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. initial_audio_waveforms (`torch.Tensor`, *optional*): Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size` corresponds to the number of prompts passed to the model. initial_audio_sampling_rate (`int`, *optional*): Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model. prompt_embeds (`torch.Tensor`, *optional*): Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, text embeddings will be computed from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from `negative_prompt` input argument. attention_mask (`torch.LongTensor`, *optional*): Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will be computed from `prompt` input argument. negative_attention_mask (`torch.LongTensor`, *optional*): Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. output_type (`str`, *optional*, defaults to `"pt"`): The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion model (LDM) output. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated audio. """ # 0. 
Convert audio input length from seconds to latent length downsample_ratio = self.vae.hop_length max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate if audio_end_in_s is None: audio_end_in_s = max_audio_length_in_s if audio_end_in_s - audio_start_in_s > max_audio_length_in_s: raise ValueError( f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'." ) waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate) waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate) waveform_length = int(self.transformer.config.sample_size) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, audio_start_in_s, audio_end_in_s, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask, initial_audio_waveforms, initial_audio_sampling_rate, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. 
Encode input prompt prompt_embeds = self.encode_prompt( prompt, device, do_classifier_free_guidance, negative_prompt, prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask, ) # Encode duration seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration( audio_start_in_s, audio_end_in_s, device, do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None), batch_size, ) # Create text_audio_duration_embeds and audio_duration_embeds text_audio_duration_embeds = torch.cat( [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1 ) audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2) # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and # to concatenate it to the embeddings if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: negative_text_audio_duration_embeds = torch.zeros_like( text_audio_duration_embeds, device=text_audio_duration_embeds.device ) text_audio_duration_embeds = torch.cat( [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0 ) audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0) bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) text_audio_duration_embeds = text_audio_duration_embeds.view( bs_embed * num_waveforms_per_prompt, seq_len, hidden_size ) audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) audio_duration_embeds = audio_duration_embeds.view( bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1] ) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_vae = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_waveforms_per_prompt, num_channels_vae, waveform_length, text_audio_duration_embeds.dtype, device, generator, latents, initial_audio_waveforms, num_waveforms_per_prompt, audio_channels=self.vae.config.audio_channels, ) # 6. Prepare extra step kwargs extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Prepare rotary positional embedding rotary_embedding = get_1d_rotary_pos_embed( self.rotary_embed_dim, latents.shape[2] + audio_duration_embeds.shape[1], use_real=True, repeat_interleave_real=False, ) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.transformer( latent_model_input, t.unsqueeze(0), encoder_hidden_states=text_audio_duration_embeds, global_hidden_states=audio_duration_embeds, rotary_embedding=rotary_embedding, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: 
step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # 9. Post-processing if not output_type == "latent": audio = self.vae.decode(latents).sample else: return AudioPipelineOutput(audios=latents) audio = audio[:, :, waveform_start:waveform_end] if output_type == "np": audio = audio.cpu().float().numpy() self.maybe_free_model_hooks() if not return_dict: return (audio,) return AudioPipelineOutput(audios=audio) ================================================ FILE: src/diffusers/pipelines/stable_cascade/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_stable_cascade"] = ["StableCascadeDecoderPipeline"] _import_structure["pipeline_stable_cascade_combined"] = ["StableCascadeCombinedPipeline"] _import_structure["pipeline_stable_cascade_prior"] = ["StableCascadePriorPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_stable_cascade import StableCascadeDecoderPipeline from .pipeline_stable_cascade_combined import StableCascadeCombinedPipeline from .pipeline_stable_cascade_prior import StableCascadePriorPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for 
name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, Dict, List, Optional, Union import torch from transformers import CLIPTextModel, CLIPTokenizer from ...models import StableCascadeUNet from ...schedulers import DDPMWuerstchenScheduler from ...utils import is_torch_version, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline >>> prior_pipe = StableCascadePriorPipeline.from_pretrained( ... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16 ... ).to("cuda") >>> gen_pipe = StableCascadeDecoderPipeline.from_pretrain( ... "stabilityai/stable-cascade", torch_dtype=torch.float16 ... 
).to("cuda") >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" >>> prior_output = pipe(prompt) >>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt) ``` """ class StableCascadeDecoderPipeline(DiffusionPipeline): """ Pipeline for generating images from the Stable Cascade model. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: tokenizer (`CLIPTokenizer`): The CLIP tokenizer. text_encoder (`CLIPTextModel`): The CLIP text encoder. decoder ([`StableCascadeUNet`]): The Stable Cascade decoder unet. vqgan ([`PaellaVQModel`]): The VQGAN model. scheduler ([`DDPMWuerstchenScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. latent_dim_scale (float, `optional`, defaults to 10.67): Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and width=int(24*10.67)=256 in order to match the training conditions. 
""" unet_name = "decoder" text_encoder_name = "text_encoder" model_cpu_offload_seq = "text_encoder->decoder->vqgan" _callback_tensor_inputs = [ "latents", "prompt_embeds_pooled", "negative_prompt_embeds", "image_embeddings", ] def __init__( self, decoder: StableCascadeUNet, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, scheduler: DDPMWuerstchenScheduler, vqgan: PaellaVQModel, latent_dim_scale: float = 10.67, ) -> None: super().__init__() self.register_modules( decoder=decoder, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=scheduler, vqgan=vqgan, ) self.register_to_config(latent_dim_scale=latent_dim_scale) def prepare_latents( self, batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler ): _, channels, height, width = image_embeddings.shape latents_shape = ( batch_size * num_images_per_prompt, 4, int(height * self.config.latent_dim_scale), int(width * self.config.latent_dim_scale), ) if latents is None: latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma return latents def encode_prompt( self, device, batch_size, num_images_per_prompt, do_classifier_free_guidance, prompt=None, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, prompt_embeds_pooled: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds_pooled: Optional[torch.Tensor] = None, ): if prompt_embeds is None: # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask untruncated_ids = self.tokenizer(prompt, padding="longest", 
return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] attention_mask = attention_mask[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True ) prompt_embeds = text_encoder_output.hidden_states[-1] if prompt_embeds_pooled is None: prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1) prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device) prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0) if negative_prompt_embeds is None and do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds_text_encoder_output = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device), output_hidden_states=True, ) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1] negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) seq_len = negative_prompt_embeds_pooled.shape[1] negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to( dtype=self.text_encoder.dtype, device=device ) negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view( batch_size * num_images_per_prompt, seq_len, -1 ) # done duplicates return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled def check_inputs( self, prompt, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, ): if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in 
self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." 
) @property def guidance_scale(self): return self._guidance_scale @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @property def num_timesteps(self): return self._num_timesteps @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, image_embeddings: Union[torch.Tensor, List[torch.Tensor]], prompt: Union[str, List[str]] = None, num_inference_steps: int = 10, guidance_scale: float = 0.0, negative_prompt: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.Tensor] = None, prompt_embeds_pooled: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds_pooled: Optional[torch.Tensor] = None, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], ): """ Function invoked when calling the pipeline for generation. Args: image_embedding (`torch.Tensor` or `List[torch.Tensor]`): Image Embeddings either extracted from an image or generated by a Prior Model. prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. num_inference_steps (`int`, *optional*, defaults to 12): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 0.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `decoder_guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `decoder_guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. prompt_embeds_pooled (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_embeds_pooled (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input argument. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. 
output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image embeddings. """ # 0. Define commonly used variables device = self._execution_device dtype = self.decoder.dtype self._guidance_scale = guidance_scale if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16: raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.") # 1. Check inputs. 
Raise error if not correct self.check_inputs( prompt, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) if isinstance(image_embeddings, list): image_embeddings = torch.cat(image_embeddings, dim=0) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Compute the effective number of images per prompt # We must account for the fact that the image embeddings from the prior can be generated with num_images_per_prompt > 1 # This results in a case where a single prompt is associated with multiple image embeddings # Divide the number of image embeddings by the batch size to determine if this is the case. num_images_per_prompt = num_images_per_prompt * (image_embeddings.shape[0] // batch_size) # 2. Encode caption if prompt_embeds is None and negative_prompt_embeds is None: _, prompt_embeds_pooled, _, negative_prompt_embeds_pooled = self.encode_prompt( prompt=prompt, device=device, batch_size=batch_size, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, prompt_embeds_pooled=prompt_embeds_pooled, negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, ) # The pooled embeds from the prior are pooled again before being passed to the decoder prompt_embeds_pooled = ( torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled]) if self.do_classifier_free_guidance else prompt_embeds_pooled ) effnet = ( torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) if self.do_classifier_free_guidance else image_embeddings ) self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. 
Prepare latents latents = self.prepare_latents( batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler ) # 6. Run denoising loop self._num_timesteps = len(timesteps[:-1]) for i, t in enumerate(self.progress_bar(timesteps[:-1])): timestep_ratio = t.expand(latents.size(0)).to(dtype) # 7. Denoise latents predicted_latents = self.decoder( sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio, clip_text_pooled=prompt_embeds_pooled, effnet=effnet, return_dict=False, )[0] # 8. Check for classifier free guidance and apply it if self.do_classifier_free_guidance: predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2) predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale) # 9. Renoise latents to next timestep latents = self.scheduler.step( model_output=predicted_latents, timestep=timestep_ratio, sample=latents, generator=generator, ).prev_sample if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError( f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}" ) if not output_type == "latent": # 10. 
Scale and decode the image latents with vq-vae latents = self.vqgan.config.scale_factor * latents images = self.vqgan.decode(latents).sample.clamp(0, 1) if output_type == "np": images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work elif output_type == "pil": images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work images = self.numpy_to_pil(images) else: images = latents # Offload all models self.maybe_free_model_hooks() if not return_dict: return images return ImagePipelineOutput(images) ================================================ FILE: src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Callable, Dict, List, Optional, Union import PIL import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import StableCascadeUNet from ...schedulers import DDPMWuerstchenScheduler from ...utils import is_torch_version, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel from .pipeline_stable_cascade import StableCascadeDecoderPipeline from .pipeline_stable_cascade_prior import StableCascadePriorPipeline TEXT2IMAGE_EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableCascadeCombinedPipeline >>> pipe = StableCascadeCombinedPipeline.from_pretrained( ... "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16 ... ) >>> pipe.enable_model_cpu_offload() >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" >>> images = pipe(prompt=prompt) ``` """ class StableCascadeCombinedPipeline(DiffusionPipeline): """ Combined Pipeline for text-to-image generation using Stable Cascade. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: tokenizer (`CLIPTokenizer`): The decoder tokenizer to be used for text inputs. text_encoder (`CLIPTextModel`): The decoder text encoder to be used for text inputs. decoder (`StableCascadeUNet`): The decoder model to be used for decoder image generation pipeline. scheduler (`DDPMWuerstchenScheduler`): The scheduler to be used for decoder image generation pipeline. vqgan (`PaellaVQModel`): The VQGAN model to be used for decoder image generation pipeline. feature_extractor ([`~transformers.CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `image_encoder`. 
image_encoder ([`CLIPVisionModelWithProjection`]): Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). prior_prior (`StableCascadeUNet`): The prior model to be used for prior pipeline. prior_scheduler (`DDPMWuerstchenScheduler`): The scheduler to be used for prior pipeline. """ _load_connected_pipes = True _optional_components = ["prior_feature_extractor", "prior_image_encoder"] def __init__( self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, decoder: StableCascadeUNet, scheduler: DDPMWuerstchenScheduler, vqgan: PaellaVQModel, prior_prior: StableCascadeUNet, prior_text_encoder: CLIPTextModel, prior_tokenizer: CLIPTokenizer, prior_scheduler: DDPMWuerstchenScheduler, prior_feature_extractor: Optional[CLIPImageProcessor] = None, prior_image_encoder: Optional[CLIPVisionModelWithProjection] = None, ): super().__init__() self.register_modules( text_encoder=text_encoder, tokenizer=tokenizer, decoder=decoder, scheduler=scheduler, vqgan=vqgan, prior_text_encoder=prior_text_encoder, prior_tokenizer=prior_tokenizer, prior_prior=prior_prior, prior_scheduler=prior_scheduler, prior_feature_extractor=prior_feature_extractor, prior_image_encoder=prior_image_encoder, ) self.prior_pipe = StableCascadePriorPipeline( prior=prior_prior, text_encoder=prior_text_encoder, tokenizer=prior_tokenizer, scheduler=prior_scheduler, image_encoder=prior_image_encoder, feature_extractor=prior_feature_extractor, ) self.decoder_pipe = StableCascadeDecoderPipeline( text_encoder=text_encoder, tokenizer=tokenizer, decoder=decoder, scheduler=scheduler, vqgan=vqgan, ) def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on 
performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): r""" Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis. Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower. 
""" self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) def progress_bar(self, iterable=None, total=None): self.prior_pipe.progress_bar(iterable=iterable, total=total) self.decoder_pipe.progress_bar(iterable=iterable, total=total) def set_progress_bar_config(self, **kwargs): self.prior_pipe.set_progress_bar_config(**kwargs) self.decoder_pipe.set_progress_bar_config(**kwargs) @torch.no_grad() @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING) def __call__( self, prompt: Optional[Union[str, List[str]]] = None, images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None, height: int = 512, width: int = 512, prior_num_inference_steps: int = 60, prior_guidance_scale: float = 4.0, num_inference_steps: int = 12, decoder_guidance_scale: float = 0.0, negative_prompt: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.Tensor] = None, prompt_embeds_pooled: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds_pooled: Optional[torch.Tensor] = None, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation for the prior and decoder. images (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, *optional*): The images to guide the image generation for the prior. 
negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. prompt_embeds_pooled (`torch.Tensor`, *optional*): Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_embeds_pooled (`torch.Tensor`, *optional*): Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. prior_guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `prior_guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. 
prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 60): The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. For more specific timestep spacing, you can pass customized `prior_timesteps` num_inference_steps (`int`, *optional*, defaults to 12): The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. For more specific timestep spacing, you can pass customized `timesteps` decoder_guidance_scale (`float`, *optional*, defaults to 0.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 
prior_callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. prior_callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ dtype = self.decoder_pipe.decoder.dtype if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16: raise ValueError( "`StableCascadeCombinedPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype." 
) prior_outputs = self.prior_pipe( prompt=prompt if prompt_embeds is None else None, images=images, height=height, width=width, num_inference_steps=prior_num_inference_steps, guidance_scale=prior_guidance_scale, negative_prompt=negative_prompt if negative_prompt_embeds is None else None, prompt_embeds=prompt_embeds, prompt_embeds_pooled=prompt_embeds_pooled, negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, num_images_per_prompt=num_images_per_prompt, generator=generator, latents=latents, output_type="pt", return_dict=True, callback_on_step_end=prior_callback_on_step_end, callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, ) image_embeddings = prior_outputs.image_embeddings prompt_embeds = prior_outputs.get("prompt_embeds", None) prompt_embeds_pooled = prior_outputs.get("prompt_embeds_pooled", None) negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None) negative_prompt_embeds_pooled = prior_outputs.get("negative_prompt_embeds_pooled", None) outputs = self.decoder_pipe( image_embeddings=image_embeddings, prompt=prompt if prompt_embeds is None else None, num_inference_steps=num_inference_steps, guidance_scale=decoder_guidance_scale, negative_prompt=negative_prompt if negative_prompt_embeds is None else None, prompt_embeds=prompt_embeds, prompt_embeds_pooled=prompt_embeds_pooled, negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, generator=generator, output_type=output_type, return_dict=return_dict, callback_on_step_end=callback_on_step_end, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) return outputs ================================================ FILE: src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from math import ceil
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection

from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]

# Substituted into `StableCascadePriorPipeline.__call__`'s docstring by `replace_example_docstring`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableCascadePriorPipeline

        >>> prior_pipe = StableCascadePriorPipeline.from_pretrained(
        ...     "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
        ... ).to("cuda")

        >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
        >>> prior_output = prior_pipe(prompt)
        ```
"""


@dataclass
class StableCascadePriorPipelineOutput(BaseOutput):
    """
    Output class for `StableCascadePriorPipeline`.

    Args:
        image_embeddings (`torch.Tensor` or `np.ndarray`)
            Prior image embeddings for text prompt
        prompt_embeds (`torch.Tensor` or `np.ndarray`):
            Text embeddings for the prompt.
        prompt_embeds_pooled (`torch.Tensor` or `np.ndarray`):
            Pooled text embeddings for the prompt.
        negative_prompt_embeds (`torch.Tensor` or `np.ndarray`):
            Text embeddings for the negative prompt.
        negative_prompt_embeds_pooled (`torch.Tensor` or `np.ndarray`):
            Pooled text embeddings for the negative prompt.
    """

    image_embeddings: Union[torch.Tensor, np.ndarray]
    prompt_embeds: Union[torch.Tensor, np.ndarray]
    prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]
    negative_prompt_embeds: Union[torch.Tensor, np.ndarray]
    negative_prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]


class StableCascadePriorPipeline(DiffusionPipeline):
    """
    Pipeline for generating image prior for Stable Cascade.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        prior ([`StableCascadeUNet`]):
            The Stable Cascade prior to approximate the image embedding from the text and/or image embedding.
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder
            ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `image_encoder`.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        scheduler ([`DDPMWuerstchenScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
        resolution_multiple ('float', *optional*, defaults to 42.67):
            Divisor applied to the requested pixel size: the latent grid is `ceil(height / resolution_multiple)` by
            `ceil(width / resolution_multiple)`.
    """

    unet_name = "prior"
    text_encoder_name = "text_encoder"
    model_cpu_offload_seq = "image_encoder->text_encoder->prior"
    _optional_components = ["image_encoder", "feature_extractor"]
    _callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"]

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        prior: StableCascadeUNet,
        scheduler: DDPMWuerstchenScheduler,
        resolution_multiple: float = 42.67,
        feature_extractor: Optional[CLIPImageProcessor] = None,
        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
    ) -> None:
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            prior=prior,
            scheduler=scheduler,
        )
        self.register_to_config(resolution_multiple=resolution_multiple)

    def prepare_latents(
        self, batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, scheduler
    ):
        """
        Return the initial latents, scaled by `scheduler.init_noise_sigma`.

        Fresh noise is drawn with `randn_tensor` when `latents` is `None`; otherwise the supplied tensor is moved to
        `device` after its shape has been validated.

        Raises:
            ValueError: If `latents` is given and its shape differs from
                `(num_images_per_prompt * batch_size, prior.config.in_channels, latent_height, latent_width)`.
        """
        latent_shape = (
            num_images_per_prompt * batch_size,
            self.prior.config.in_channels,
            ceil(height / self.config.resolution_multiple),
            ceil(width / self.config.resolution_multiple),
        )

        if latents is None:
            latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != latent_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latent_shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def encode_prompt(
        self,
        device,
        batch_size,
        num_images_per_prompt,
        do_classifier_free_guidance,
        prompt=None,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_pooled: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds_pooled: Optional[torch.Tensor] = None,
    ):
        """
        Encode the prompt (and, under classifier-free guidance, the negative prompt) with the text encoder.

        Pre-computed embeddings are used as-is when supplied. Every returned positive embedding is cast to the text
        encoder dtype, moved to `device` and repeated `num_images_per_prompt` times along the batch dimension; the
        negative embeddings get the same treatment only when `do_classifier_free_guidance` is true.

        Returns:
            `tuple`: `(prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled)`.
            The negative entries are returned unchanged (possibly `None`) when guidance is disabled.

        Raises:
            TypeError: If `prompt` and `negative_prompt` are both given but have different types.
            ValueError: If `negative_prompt` is a list whose length differs from `batch_size`.
        """
        if prompt_embeds is None:
            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            attention_mask = text_inputs.attention_mask

            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )
                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
                attention_mask = attention_mask[:, : self.tokenizer.model_max_length]

            text_encoder_output = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True
            )
            prompt_embeds = text_encoder_output.hidden_states[-1]
            if prompt_embeds_pooled is None:
                prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1)

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
        prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device)
        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)

        if negative_prompt_embeds is None and do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            # Only compare types when a textual prompt exists: with pre-computed `prompt_embeds` the prompt is `None`
            # and the comparison would reject every `negative_prompt`.
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                # One copy per batch entry so the `view` below matches when the batch comes from `prompt_embeds`.
                uncond_tokens = [negative_prompt] * batch_size
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            negative_prompt_embeds_text_encoder_output = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=uncond_input.attention_mask.to(device),
                output_hidden_states=True,
            )

            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1]
            negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            seq_len = negative_prompt_embeds_pooled.shape[1]
            negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to(
                dtype=self.text_encoder.dtype, device=device
            )
            negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            # done duplicates

        return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled

    def encode_image(self, images, device, dtype, batch_size, num_images_per_prompt):
        """
        Encode each entry of `images` with `feature_extractor` + `image_encoder`.

        The per-image `image_embeds` are concatenated along dim 1 and the result is tiled
        `batch_size * num_images_per_prompt` times along dim 0.

        Returns:
            `tuple`: `(image_embeds, negative_image_embeds)` where the negative embeddings are zeros of the same shape.
        """
        image_embeds = []
        for image in images:
            image = self.feature_extractor(image, return_tensors="pt").pixel_values
            image = image.to(device=device, dtype=dtype)
            image_embed = self.image_encoder(image).image_embeds.unsqueeze(1)
            image_embeds.append(image_embed)
        image_embeds = torch.cat(image_embeds, dim=1)

        image_embeds = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1)
        negative_image_embeds = torch.zeros_like(image_embeds)

        return image_embeds, negative_image_embeds

    def check_inputs(
        self,
        prompt,
        images=None,
        image_embeds=None,
        negative_prompt=None,
        prompt_embeds=None,
        prompt_embeds_pooled=None,
        negative_prompt_embeds=None,
        negative_prompt_embeds_pooled=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """
        Validate the argument combination passed to `__call__`.

        Raises:
            ValueError: On unknown callback tensor names, on mutually exclusive arguments being passed together, on a
                missing prompt, on embeddings passed without their pooled counterpart, or on mismatching
                positive/negative embedding shapes.
            TypeError: If an entry of `images` is neither a `torch.Tensor` nor a `PIL.Image.Image`.
        """
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if prompt_embeds is not None and prompt_embeds_pooled is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`"
            )

        if negative_prompt_embeds is not None and negative_prompt_embeds_pooled is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_pooled` must also be provided. Make sure to generate `negative_prompt_embeds_pooled` from the same text encoder that was used to generate `negative_prompt_embeds`"
            )

        if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None:
            if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape:
                raise ValueError(
                    "`prompt_embeds_pooled` and `negative_prompt_embeds_pooled` must have the same shape when passed"
                    f" directly, but got: `prompt_embeds_pooled` {prompt_embeds_pooled.shape} !="
                    f" `negative_prompt_embeds_pooled` {negative_prompt_embeds_pooled.shape}."
                )

        if image_embeds is not None and images is not None:
            raise ValueError(
                f"Cannot forward both `images`: {images} and `image_embeds`: {image_embeds}. Please make sure to"
                " only forward one of the two."
            )

        # `is not None` rather than truthiness: the truth value of a multi-element tensor is ambiguous and raises.
        if images is not None:
            for i, image in enumerate(images):
                if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
                    raise TypeError(
                        "'images' must contain images of type 'torch.Tensor' or 'PIL.Image.Image', but got"
                        f" {type(image)} for image number {i}."
                    )

    @property
    def guidance_scale(self):
        """Guidance scale stored by the most recent `__call__`."""
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        """Whether classifier-free guidance is active, i.e. `guidance_scale > 1`."""
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        """Number of timesteps iterated by the most recent `__call__`."""
        return self._num_timesteps

    def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
        """
        Convert timestep index `t` into the ratio fed to the prior as `timestep_ratio`.

        `alphas_cumprod[t]` is clamped to `[0, 1]` and mapped through the inverse of a squared-cosine curve with
        offset `s = 0.003`.
        """
        s = torch.tensor([0.003])
        clamp_range = [0, 1]
        min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
        var = alphas_cumprod[t]
        var = var.clamp(*clamp_range)
        s, min_var = s.to(var.device), min_var.to(var.device)
        ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
        return ratio

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 20,
        timesteps: List[float] = None,
        guidance_scale: float = 4.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_pooled: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds_pooled: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pt",
        return_dict: bool = True,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            images (`List[torch.Tensor]` or `List[PIL.Image.Image]`, *optional*):
                Images to condition the prior on. Each entry is encoded with `feature_extractor` and `image_encoder`.
                Cannot be combined with `image_embeds`.
            height (`int`, *optional*, defaults to 1024):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 1024):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 20):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[float]`, *optional*):
                Currently unused: the schedule is always built from `num_inference_steps`.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored unless `guidance_scale > 1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_embeds_pooled (`torch.Tensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            negative_prompt_embeds_pooled (`torch.Tensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt`
                input argument.
            image_embeds (`torch.Tensor`, *optional*):
                Pre-generated image embeddings. Can be used to easily tweak image inputs, *e.g.* prompt weighting. If
                not provided, image embeddings will be generated from the `images` input argument if existing.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pt"`):
                The output format of the generated embeddings. With `"np"`, `image_embeddings`, `prompt_embeds` and
                `negative_prompt_embeds` are converted to `np.ndarray` (the pooled embeddings are returned as
                tensors); any other value leaves the `torch.Tensor` outputs untouched.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`StableCascadePriorPipelineOutput`] instead of a plain tuple.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`StableCascadePriorPipelineOutput`] or `tuple` [`StableCascadePriorPipelineOutput`] if `return_dict` is
            True, otherwise a `tuple` `(image_embeddings, prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds,
            negative_prompt_embeds_pooled)`.
        """

        # 0. Define commonly used variables
        device = self._execution_device
        dtype = next(self.prior.parameters()).dtype
        self._guidance_scale = guidance_scale

        # 1. Check inputs. Raise error if not correct. This runs before the batch size is derived so that a missing
        # prompt surfaces as the descriptive `ValueError` instead of an `AttributeError` on `None.shape`.
        self.check_inputs(
            prompt,
            images=images,
            image_embeds=image_embeds,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            prompt_embeds_pooled=prompt_embeds_pooled,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        )

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 2. Encode caption + images
        (
            prompt_embeds,
            prompt_embeds_pooled,
            negative_prompt_embeds,
            negative_prompt_embeds_pooled,
        ) = self.encode_prompt(
            prompt=prompt,
            device=device,
            batch_size=batch_size,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            prompt_embeds_pooled=prompt_embeds_pooled,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
        )

        if images is not None:
            image_embeds_pooled, uncond_image_embeds_pooled = self.encode_image(
                images=images,
                device=device,
                dtype=dtype,
                batch_size=batch_size,
                num_images_per_prompt=num_images_per_prompt,
            )
        elif image_embeds is not None:
            image_embeds_pooled = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1)
            uncond_image_embeds_pooled = torch.zeros_like(image_embeds_pooled)
        else:
            # No image conditioning: feed all-zero image embeddings to the prior.
            image_embeds_pooled = torch.zeros(
                batch_size * num_images_per_prompt,
                1,
                self.prior.config.clip_image_in_channels,
                device=device,
                dtype=dtype,
            )
            uncond_image_embeds_pooled = torch.zeros(
                batch_size * num_images_per_prompt,
                1,
                self.prior.config.clip_image_in_channels,
                device=device,
                dtype=dtype,
            )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the conditional and unconditional embeddings into a single batch
        # to avoid doing two forward passes. The decision is keyed on the guidance flag (not on whether negative
        # embeddings happen to exist) so the conditioning batch always matches the latents batch fed to the prior.
        if self.do_classifier_free_guidance:
            image_embeds = torch.cat([image_embeds_pooled, uncond_image_embeds_pooled], dim=0)
            text_encoder_hidden_states = torch.cat([prompt_embeds, negative_prompt_embeds])
            text_encoder_pooled = torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled])
        else:
            image_embeds = image_embeds_pooled
            text_encoder_hidden_states = prompt_embeds
            text_encoder_pooled = prompt_embeds_pooled

        # 4. Prepare and set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latents
        latents = self.prepare_latents(
            batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
        )

        uses_wuerstchen_scheduler = isinstance(self.scheduler, DDPMWuerstchenScheduler)
        if uses_wuerstchen_scheduler:
            timesteps = timesteps[:-1]
        else:
            if self.scheduler.config.clip_sample:
                self.scheduler.config.clip_sample = False  # disable sample clipping
                logger.warning(" set `clip_sample` to be False")

        # 6. Run denoising loop
        if hasattr(self.scheduler, "betas"):
            alphas = 1.0 - self.scheduler.betas
            alphas_cumprod = torch.cumprod(alphas, dim=0)
        else:
            alphas_cumprod = []

        self._num_timesteps = len(timesteps)
        for i, t in enumerate(self.progress_bar(timesteps)):
            if not uses_wuerstchen_scheduler:
                if len(alphas_cumprod) > 0:
                    timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
                    timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
                else:
                    timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
            else:
                timestep_ratio = t.expand(latents.size(0)).to(dtype)

            # 7. Denoise image embeddings
            predicted_image_embedding = self.prior(
                sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
                timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio,
                clip_text_pooled=text_encoder_pooled,
                clip_text=text_encoder_hidden_states,
                clip_img=image_embeds,
                return_dict=False,
            )[0]

            # 8. Check for classifier free guidance and apply it
            if self.do_classifier_free_guidance:
                predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2)
                predicted_image_embedding = torch.lerp(
                    predicted_image_embedding_uncond, predicted_image_embedding_text, self.guidance_scale
                )

            # 9. Renoise latents to next timestep
            if not uses_wuerstchen_scheduler:
                # Other schedulers are stepped with the raw timestep rather than the ratio.
                timestep_ratio = t
            latents = self.scheduler.step(
                model_output=predicted_image_embedding, timestep=timestep_ratio, sample=latents, generator=generator
            ).prev_sample

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

        # Offload all models
        self.maybe_free_model_hooks()

        if output_type == "np":
            latents = latents.cpu().float().numpy()  # float() as bfloat16-> numpy doesnt work
            prompt_embeds = prompt_embeds.cpu().float().numpy()  # float() as bfloat16-> numpy doesnt work
            negative_prompt_embeds = (
                negative_prompt_embeds.cpu().float().numpy() if negative_prompt_embeds is not None else None
            )  # float() as bfloat16-> numpy doesnt work

        if not return_dict:
            return (
                latents,
                prompt_embeds,
                prompt_embeds_pooled,
                negative_prompt_embeds,
                negative_prompt_embeds_pooled,
            )

        return StableCascadePriorPipelineOutput(
            image_embeddings=latents,
            prompt_embeds=prompt_embeds,
            prompt_embeds_pooled=prompt_embeds_pooled,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
        )
- If you don't want to rely on the Hugging Face Hub and have to pass an authentication token, you can download the weights with `git lfs install; git clone https://huggingface.co/runwayml/stable-diffusion-v1-5` and instead pass the local path to the cloned folder to `from_pretrained` as shown below.
- Stable Diffusion can work with a variety of different samplers, as shown below.

## Available Pipelines:

| Pipeline | Tasks | Colab
|---|---|:---:|
| [pipeline_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [pipeline_stable_diffusion_img2img](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [pipeline_stable_diffusion_inpaint](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)

## Examples:

### Using Stable Diffusion without being logged into the Hub.

If you want to download the model weights using a single Python line, you need to be logged in via `huggingface-cli login`.
```python from diffusers import DiffusionPipeline pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") ``` This however can make it difficult to build applications on top of `diffusers` as you will always have to pass the token around. A potential way to solve this issue is by downloading the weights to a local path `"./stable-diffusion-v1-5"`: ``` git lfs install git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 ``` and simply passing the local path to `from_pretrained`: ```python from diffusers import StableDiffusionPipeline pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-5") ``` ### Text-to-Image with default PLMS scheduler ```python # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` ### Text-to-Image with DDIM scheduler ```python # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline, DDIMScheduler scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") pipe = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", scheduler=scheduler, ).to("cuda") prompt = "a photo of an astronaut riding a horse on mars" image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` ### Text-to-Image with K-LMS scheduler ```python # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") pipe = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", scheduler=lms, ).to("cuda") prompt = "a photo of an astronaut riding a horse on 
mars" image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` ### CycleDiffusion using Stable Diffusion and DDIM scheduler ```python import requests import torch from PIL import Image from io import BytesIO from diffusers import CycleDiffusionPipeline, DDIMScheduler # load the scheduler. CycleDiffusion only supports stochastic schedulers. # load the pipeline # make sure you're logged in with `huggingface-cli login` model_id_or_path = "CompVis/stable-diffusion-v1-4" scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler") pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") # let's download an initial image url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png" response = requests.get(url) init_image = Image.open(BytesIO(response.content)).convert("RGB") init_image = init_image.resize((512, 512)) init_image.save("horse.png") # let's specify a prompt source_prompt = "An astronaut riding a horse" prompt = "An astronaut riding an elephant" # call the pipeline image = pipe( prompt=prompt, source_prompt=source_prompt, image=init_image, num_inference_steps=100, eta=0.1, strength=0.8, guidance_scale=2, source_guidance_scale=1, ).images[0] image.save("horse_to_elephant.png") # let's try another example # See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png" response = requests.get(url) init_image = Image.open(BytesIO(response.content)).convert("RGB") init_image = init_image.resize((512, 512)) init_image.save("black.png") source_prompt = "A black colored car" prompt = "A blue colored car" # call the pipeline torch.manual_seed(0) image = pipe( prompt=prompt, source_prompt=source_prompt, image=init_image, num_inference_steps=100, eta=0.1, strength=0.85, guidance_scale=3, 
source_guidance_scale=1, ).images[0] image.save("black_to_blue.png") ``` ================================================ FILE: src/diffusers/pipelines/stable_diffusion/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_flax_available, is_k_diffusion_available, is_k_diffusion_version, is_onnx_available, is_torch_available, is_transformers_available, is_transformers_version, ) _dummy_objects = {} _additional_imports = {} _import_structure = {"pipeline_output": ["StableDiffusionPipelineOutput"]} if is_transformers_available() and is_flax_available(): _import_structure["pipeline_output"].extend(["FlaxStableDiffusionPipelineOutput"]) try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["clip_image_project_model"] = ["CLIPImageProjection"] _import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"] _import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"] _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"] _import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"] _import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"] _import_structure["pipeline_stable_diffusion_inpaint"] = ["StableDiffusionInpaintPipeline"] _import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"] _import_structure["pipeline_stable_diffusion_instruct_pix2pix"] = 
["StableDiffusionInstructPix2PixPipeline"] _import_structure["pipeline_stable_diffusion_latent_upscale"] = ["StableDiffusionLatentUpscalePipeline"] _import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"] _import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"] _import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"] _import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"] _import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"] _import_structure["safety_checker"] = ["StableDiffusionSafetyChecker"] _import_structure["stable_unclip_image_normalizer"] = ["StableUnCLIPImageNormalizer"] try: if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( StableDiffusionImageVariationPipeline, ) _dummy_objects.update({"StableDiffusionImageVariationPipeline": StableDiffusionImageVariationPipeline}) else: _import_structure["pipeline_stable_diffusion_image_variation"] = ["StableDiffusionImageVariationPipeline"] try: if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( StableDiffusionDepth2ImgPipeline, ) _dummy_objects.update( { "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, } ) else: _import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"] try: if not (is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_onnx_objects # noqa F403 
_dummy_objects.update(get_objects_from_module(dummy_onnx_objects)) else: _import_structure["pipeline_onnx_stable_diffusion"] = [ "OnnxStableDiffusionPipeline", "StableDiffusionOnnxPipeline", ] _import_structure["pipeline_onnx_stable_diffusion_img2img"] = ["OnnxStableDiffusionImg2ImgPipeline"] _import_structure["pipeline_onnx_stable_diffusion_inpaint"] = ["OnnxStableDiffusionInpaintPipeline"] _import_structure["pipeline_onnx_stable_diffusion_inpaint_legacy"] = ["OnnxStableDiffusionInpaintPipelineLegacy"] _import_structure["pipeline_onnx_stable_diffusion_upscale"] = ["OnnxStableDiffusionUpscalePipeline"] if is_transformers_available() and is_flax_available(): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState _additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState}) _import_structure["pipeline_flax_stable_diffusion"] = ["FlaxStableDiffusionPipeline"] _import_structure["pipeline_flax_stable_diffusion_img2img"] = ["FlaxStableDiffusionImg2ImgPipeline"] _import_structure["pipeline_flax_stable_diffusion_inpaint"] = ["FlaxStableDiffusionInpaintPipeline"] _import_structure["safety_checker_flax"] = ["FlaxStableDiffusionSafetyChecker"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .clip_image_project_model import CLIPImageProjection from .pipeline_stable_diffusion import ( StableDiffusionPipeline, StableDiffusionPipelineOutput, ) from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline from .pipeline_stable_diffusion_instruct_pix2pix import ( StableDiffusionInstructPix2PixPipeline, ) from .pipeline_stable_diffusion_latent_upscale import ( StableDiffusionLatentUpscalePipeline, ) from .pipeline_stable_diffusion_upscale import 
StableDiffusionUpscalePipeline from .pipeline_stable_unclip import StableUnCLIPPipeline from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline from .safety_checker import StableDiffusionSafetyChecker from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer try: if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( StableDiffusionImageVariationPipeline, ) else: from .pipeline_stable_diffusion_image_variation import ( StableDiffusionImageVariationPipeline, ) try: if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import StableDiffusionDepth2ImgPipeline else: from .pipeline_stable_diffusion_depth2img import ( StableDiffusionDepth2ImgPipeline, ) try: if not (is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_onnx_objects import * else: from .pipeline_onnx_stable_diffusion import ( OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline, ) from .pipeline_onnx_stable_diffusion_img2img import ( OnnxStableDiffusionImg2ImgPipeline, ) from .pipeline_onnx_stable_diffusion_inpaint import ( OnnxStableDiffusionInpaintPipeline, ) from .pipeline_onnx_stable_diffusion_upscale import ( OnnxStableDiffusionUpscalePipeline, ) try: if not (is_transformers_available() and is_flax_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_flax_objects import * else: from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline from .pipeline_flax_stable_diffusion_img2img import ( FlaxStableDiffusionImg2ImgPipeline, ) from 
class CLIPImageProjection(ModelMixin, ConfigMixin):
    """
    Single bias-free linear layer that maps a `hidden_size`-wide input to a `hidden_size`-wide output.

    Args:
        hidden_size (`int`, defaults to 768): Width of both the input and the output of the projection.
    """

    @register_to_config
    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.hidden_size = hidden_size
        # Square projection: input and output widths are both `hidden_size`, and no bias term is learned.
        self.project = nn.Linear(in_features=self.hidden_size, out_features=self.hidden_size, bias=False)

    def forward(self, x):
        """Apply the linear projection to `x` and return the result."""
        projected = self.project(x)
        return projected
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Drops dot-separated segments from `path`. A count of zero or more removes that many leading segments; a
    negative count removes that many trailing segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments < 0:
        kept = segments[:n_shave_prefix_segments]
    else:
        kept = segments[n_shave_prefix_segments:]
    return ".".join(kept)


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Renames resnet sub-layer keys to the new naming scheme (local renaming) and returns a list of
    `{"old": ..., "new": ...}` dicts, one per input key and in input order.
    """
    # Applied in this exact order to every key.
    substitutions = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    mapping = []
    for original_key in old_list:
        renamed = original_key
        for ldm_name, diffusers_name in substitutions:
            renamed = renamed.replace(ldm_name, diffusers_name)
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": original_key, "new": renamed})

    return mapping


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Renames VAE resnet keys to the new naming scheme (local renaming): `nin_shortcut` becomes `conv_shortcut`,
    then leading/trailing segments are shaved as requested.
    """
    return [
        {
            "old": original_key,
            "new": shave_segments(
                original_key.replace("nin_shortcut", "conv_shortcut"),
                n_shave_prefix_segments=n_shave_prefix_segments,
            ),
        }
        for original_key in old_list
    ]


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Builds the old/new mapping for attention keys. No renaming is applied here: every key maps to itself, and
    `n_shave_prefix_segments` is accepted but not used.
    """
    return [{"old": original_key, "new": original_key} for original_key in old_list]
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    Final conversion step: copies the locally renamed weights from `old_checkpoint` into `checkpoint` (mutated in
    place) after applying the global renaming. Fused attention tensors listed in `attention_paths_to_split` are
    split into query/key/value entries, and `additional_replacements` are applied to every destination key.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    split_requested = attention_paths_to_split is not None

    # Split each fused attention tensor into its query, key and value parts.
    if split_requested:
        for fused_key, targets in attention_paths_to_split.items():
            fused = old_checkpoint[fused_key]
            channels = fused.shape[0] // 3

            flat_shape = (-1, channels) if len(fused.shape) == 3 else (-1)

            num_heads = fused.shape[0] // config["num_head_channels"] // 3

            per_head = fused.reshape((num_heads, 3 * channels // num_heads) + fused.shape[1:])
            query, key, value = per_head.split(channels // num_heads, dim=1)

            for role, part in (("query", query), ("key", key), ("value", value)):
                checkpoint[targets[role]] = part.reshape(flat_shape)

    # Global renaming, applied in order to every destination key.
    global_renames = (
        ("middle_block.0", "mid_block.resnets.0"),
        ("middle_block.1", "mid_block.attentions.0"),
        ("middle_block.2", "mid_block.resnets.1"),
    )

    for entry in paths:
        target = entry["new"]

        # Already written by the split above.
        if split_requested and target in attention_paths_to_split:
            continue

        for before, after in global_renames:
            target = target.replace(before, after)

        if additional_replacements is not None:
            for extra in additional_replacements:
                target = target.replace(extra["old"], extra["new"])

        source = old_checkpoint[entry["old"]]

        # proj_attn.weight has to be converted from conv 1D to linear; the same trailing-axis squeeze is applied
        # to any "attentions" key containing "to_".
        looks_like_attn = "proj_attn.weight" in target or ("attentions" in target and "to_" in target)
        rank = len(source.shape)

        if looks_like_attn and rank == 3:
            checkpoint[target] = source[:, :, 0]
        elif looks_like_attn and rank == 4:
            checkpoint[target] = source[:, :, 0, 0]
        else:
            checkpoint[target] = source
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Creates a config for the diffusers based on the config of the LDM model.

    Args:
        original_config: Nested mapping holding the LDM configuration. The UNet parameters are read from
            `model.params.control_stage_config.params` when `controlnet` is true, otherwise from
            `model.params.unet_config.params` (falling back to `model.params.network_config.params` when
            `unet_config` is absent or `None`). The VAE `ddconfig` is read from `model.params.first_stage_config`.
        image_size (`int`): Divided by the VAE scale factor to obtain `sample_size`.
        controlnet (`bool`): When true the result carries `conditioning_channels` instead of `out_channels` and
            `up_block_types`.

    Returns:
        `dict`: Keyword arguments describing the UNet (or ControlNet) in diffusers terms.

    Raises:
        KeyError: If a required entry is missing from the configuration.
        AssertionError: If `num_classes` is `"sequential"` but `adm_in_channels` is missing.
    """
    model_params = original_config["model"]["params"]
    if controlnet:
        unet_params = model_params["control_stage_config"]["params"]
    elif "unet_config" in model_params and model_params["unet_config"] is not None:
        unet_params = model_params["unet_config"]["params"]
    else:
        unet_params = model_params["network_config"]["params"]

    vae_params = model_params["first_stage_config"]["params"]["ddconfig"]

    block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
    num_blocks = len(block_out_channels)

    # Walk down the UNet: the resolution doubles after every block except the last one.
    down_block_types = []
    resolution = 1
    for i in range(num_blocks):
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
        down_block_types.append(block_type)
        if i != num_blocks - 1:
            resolution *= 2

    # Walk back up, starting from the resolution reached above and halving after every block.
    up_block_types = []
    for _ in range(num_blocks):
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    # `transformer_depth` is treated as optional, like the other optional keys below: a missing key behaves
    # exactly like an explicit `None` (one transformer layer per block) instead of raising KeyError.
    transformer_depth = unet_params["transformer_depth"] if "transformer_depth" in unet_params else None
    if transformer_depth is None:
        transformer_layers_per_block = 1
    elif isinstance(transformer_depth, int):
        transformer_layers_per_block = transformer_depth
    else:
        transformer_layers_per_block = list(transformer_depth)

    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)

    head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
    use_linear_projection = (
        unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection and head_dim is None:
        # stable diffusion 2-base-512 and 2-768
        head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
        head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]

    class_embed_type = None
    addition_embed_type = None
    addition_time_embed_dim = None
    projection_class_embeddings_input_dim = None

    # `context_dim` is optional as well: a missing key behaves like an explicit `None`. A non-int value is
    # indexed and its first element is used.
    context_dim = unet_params["context_dim"] if "context_dim" in unet_params else None
    if context_dim is not None and not isinstance(context_dim, int):
        context_dim = context_dim[0]

    if "num_classes" in unet_params and unet_params["num_classes"] == "sequential":
        if context_dim in [2048, 1280]:
            # SDXL
            addition_embed_type = "text_time"
            addition_time_embed_dim = 256
        else:
            class_embed_type = "projection"
        assert "adm_in_channels" in unet_params
        projection_class_embeddings_input_dim = unet_params["adm_in_channels"]

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params["in_channels"],
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params["num_res_blocks"],
        "cross_attention_dim": context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "addition_embed_type": addition_embed_type,
        "addition_time_embed_dim": addition_time_embed_dim,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "transformer_layers_per_block": transformer_layers_per_block,
    }

    if "disable_self_attentions" in unet_params:
        config["only_cross_attention"] = unet_params["disable_self_attentions"]

    if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
        config["num_class_embeds"] = unet_params["num_classes"]

    if controlnet:
        config["conditioning_channels"] = unet_params["hint_channels"]
    else:
        config["out_channels"] = unet_params["out_channels"]
        config["up_block_types"] = tuple(up_block_types)

    return config
def create_diffusers_schedular(original_config):
    """
    Builds a `DDIMScheduler` from the `timesteps`, `linear_start` and `linear_end` entries found under
    `original_config["model"]["params"]`, using the `scaled_linear` beta schedule.
    """
    model_params = original_config["model"]["params"]
    schedular = DDIMScheduler(
        num_train_timesteps=model_params["timesteps"],
        beta_start=model_params["linear_start"],
        beta_end=model_params["linear_end"],
        beta_schedule="scaled_linear",
    )
    return schedular


def create_ldm_bert_config(original_config):
    """
    Builds an `LDMBertConfig` from the `n_embed` and `n_layer` entries found under
    `original_config["model"]["params"]["cond_stage_config"]["params"]`. The feed-forward width is set to four
    times `n_embed`.
    """
    bert_params = original_config["model"]["params"]["cond_stage_config"]["params"]
    # Index by key, as the other config helpers in this file do: attribute access (`bert_params.n_embed`)
    # raises AttributeError when the parameters are a plain dict.
    n_embed = bert_params["n_embed"]
    config = LDMBertConfig(
        d_model=n_embed,
        encoder_layers=bert_params["n_layer"],
        encoder_ffn_dim=n_embed * 4,
    )
    return config
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") logger.warning( "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." ) for key in keys: if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) else: if sum(k.startswith("model_ema") for k in keys) > 100: logger.warning( "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" " weights (usually better for inference), please make sure to add the `--extract_ema` flag." ) for key in keys: if key.startswith(unet_key): unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) new_checkpoint = {} new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] if config["class_embed_type"] is None: # No parameters to port ... 
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] else: raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") if config["addition_embed_type"] == "text_time": new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] # Relevant to StableDiffusionUpscalePipeline if "num_class_embeds" in config: if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] if not controlnet: new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] # Retrieves the keys for the input blocks only num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) input_blocks = { layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] for layer_id in range(num_input_blocks) } # Retrieves the keys for the middle blocks only 
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) middle_blocks = { layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] for layer_id in range(num_middle_blocks) } # Retrieves the keys for the output blocks only num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) output_blocks = { layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] for layer_id in range(num_output_blocks) } for i in range(1, num_input_blocks): block_id = (i - 1) // (config["layers_per_block"] + 1) layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) resnets = [ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key ] attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] if f"input_blocks.{i}.0.op.weight" in unet_state_dict: new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.weight" ) new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.bias" ) paths = renew_resnet_paths(resnets) meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) if len(attentions): paths = renew_attention_paths(attentions) meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) resnet_0 = middle_blocks[0] attentions = middle_blocks[1] resnet_1 = middle_blocks[2] resnet_0_paths = renew_resnet_paths(resnet_0) assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) resnet_1_paths = 
renew_resnet_paths(resnet_1) assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) attentions_paths = renew_attention_paths(attentions) meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} assign_to_checkpoint( attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) for i in range(num_output_blocks): block_id = i // (config["layers_per_block"] + 1) layer_in_block_id = i % (config["layers_per_block"] + 1) output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] output_block_list = {} for layer in output_block_layers: layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) if layer_id in output_block_list: output_block_list[layer_id].append(layer_name) else: output_block_list[layer_id] = [layer_name] if len(output_block_list) > 1: resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] resnet_0_paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets) meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) output_block_list = {k: sorted(v) for k, v in sorted(output_block_list.items())} if ["conv.bias", "conv.weight"] in output_block_list.values(): index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.weight" ] new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.bias" ] # Clear attentions as they have been attributed above. 
if len(attentions) == 2: attentions = [] if len(attentions): paths = renew_attention_paths(attentions) meta_path = { "old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) else: resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) for path in resnet_0_paths: old_path = ".".join(["output_blocks", str(i), path["old"]]) new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) new_checkpoint[new_path] = unet_state_dict[old_path] if controlnet: # conditioning embedding orig_index = 0 new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) orig_index += 2 diffusers_index = 0 while diffusers_index < 6: new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) diffusers_index += 1 orig_index += 2 new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) # down blocks for i in range(num_input_blocks): new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") # mid block new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") new_checkpoint["controlnet_mid_block.bias"] = 
def convert_ldm_vae_checkpoint(checkpoint, config):
    """
    Re-key the VAE weights of an LDM-style checkpoint into the diffusers layout.

    When any key of `checkpoint` starts with `first_stage_model.`, only keys with that prefix are considered and the
    prefix is removed; otherwise every key is taken as-is. The selected tensors are then copied or renamed into a new
    dictionary.

    Args:
        checkpoint: Mapping from original parameter names to tensors.
        config: Passed through unchanged to `assign_to_checkpoint`.

    Returns:
        A new dictionary holding the converted VAE weights.
    """
    all_names = list(checkpoint.keys())
    prefixed = any(name.startswith("first_stage_model.") for name in all_names)
    vae_key = "first_stage_model." if prefixed else ""

    # Collect the VAE sub state dict, dropping the optional prefix.
    vae_state_dict = {
        name.replace(vae_key, ""): checkpoint.get(name) for name in all_names if name.startswith(vae_key)
    }

    # Layers whose tensors are copied one-to-one: (diffusers name, original name).
    direct_copies = (
        ("encoder.conv_in", "encoder.conv_in"),
        ("encoder.conv_out", "encoder.conv_out"),
        ("encoder.conv_norm_out", "encoder.norm_out"),
        ("decoder.conv_in", "decoder.conv_in"),
        ("decoder.conv_out", "decoder.conv_out"),
        ("decoder.conv_norm_out", "decoder.norm_out"),
        ("quant_conv", "quant_conv"),
        ("post_quant_conv", "post_quant_conv"),
    )
    new_checkpoint = {}
    for target, origin in direct_copies:
        for suffix in ("weight", "bias"):
            new_checkpoint[f"{target}.{suffix}"] = vae_state_dict[f"{origin}.{suffix}"]

    def _convert_mid_block(side):
        # `side` is "encoder" or "decoder"; both use two resnets followed by one attention layer.
        mid_resnet_names = [name for name in vae_state_dict if f"{side}.mid.block" in name]
        for res_idx in (1, 2):
            selected = [name for name in mid_resnet_names if f"{side}.mid.block_{res_idx}" in name]
            renamed = renew_vae_resnet_paths(selected)
            replacement = {"old": f"mid.block_{res_idx}", "new": f"mid_block.resnets.{res_idx - 1}"}
            assign_to_checkpoint(
                renamed, new_checkpoint, vae_state_dict, additional_replacements=[replacement], config=config
            )

        attention_names = [name for name in vae_state_dict if f"{side}.mid.attn" in name]
        renamed = renew_vae_attention_paths(attention_names)
        replacement = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
        assign_to_checkpoint(
            renamed, new_checkpoint, vae_state_dict, additional_replacements=[replacement], config=config
        )
        conv_attn_to_linear(new_checkpoint)

    # Group the encoder keys by down block.
    encoder_block_names = {".".join(name.split(".")[:3]) for name in vae_state_dict if "encoder.down" in name}
    num_down_blocks = len(encoder_block_names)
    down_blocks = {
        idx: [name for name in vae_state_dict if f"down.{idx}" in name] for idx in range(num_down_blocks)
    }

    # Group the decoder keys by up block.
    decoder_block_names = {".".join(name.split(".")[:3]) for name in vae_state_dict if "decoder.up" in name}
    num_up_blocks = len(decoder_block_names)
    up_blocks = {idx: [name for name in vae_state_dict if f"up.{idx}" in name] for idx in range(num_up_blocks)}

    for idx in range(num_down_blocks):
        # Every key in down_blocks[idx] already contains f"down.{idx}"; only the downsampler is filtered out.
        resnet_names = [name for name in down_blocks[idx] if f"down.{idx}.downsample" not in name]

        downsample_weight = f"encoder.down.{idx}.downsample.conv.weight"
        if downsample_weight in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{idx}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                downsample_weight
            )
            new_checkpoint[f"encoder.down_blocks.{idx}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{idx}.downsample.conv.bias"
            )

        renamed = renew_vae_resnet_paths(resnet_names)
        replacement = {"old": f"down.{idx}.block", "new": f"down_blocks.{idx}.resnets"}
        assign_to_checkpoint(
            renamed, new_checkpoint, vae_state_dict, additional_replacements=[replacement], config=config
        )

    _convert_mid_block("encoder")

    for idx in range(num_up_blocks):
        # The original checkpoint numbers the decoder blocks in the reverse order.
        source_idx = num_up_blocks - 1 - idx
        resnet_names = [name for name in up_blocks[source_idx] if f"up.{source_idx}.upsample" not in name]

        upsample_weight = f"decoder.up.{source_idx}.upsample.conv.weight"
        if upsample_weight in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{idx}.upsamplers.0.conv.weight"] = vae_state_dict[upsample_weight]
            new_checkpoint[f"decoder.up_blocks.{idx}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{source_idx}.upsample.conv.bias"
            ]

        renamed = renew_vae_resnet_paths(resnet_names)
        replacement = {"old": f"up.{source_idx}.block", "new": f"up_blocks.{idx}.resnets"}
        assign_to_checkpoint(
            renamed, new_checkpoint, vae_state_dict, additional_replacements=[replacement], config=config
        )

    _convert_mid_block("decoder")
    return new_checkpoint
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
    """
    Load the CLIP text-encoder weights found in an original checkpoint into a text model.

    Args:
        checkpoint: Mapping from original parameter names to tensors. Keys starting with
            `cond_stage_model.transformer` or `conditioner.embedders.0.transformer` are used, with that prefix (and
            the following dot) removed.
        local_files_only: Forwarded to `CLIPTextConfig.from_pretrained` when a new model has to be created.
        text_encoder: Model to load the weights into. When `None`, a `CLIPTextModel` is created from the
            `openai/clip-vit-large-patch14` configuration.

    Returns:
        The text model with the checkpoint weights loaded.

    Raises:
        ValueError: If the CLIP text configuration cannot be loaded.
    """
    if text_encoder is None:
        config_name = "openai/clip-vit-large-patch14"
        try:
            config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
        except Exception as err:
            # Keep the original failure attached so the real cause (network, cache, ...) stays visible.
            raise ValueError(
                f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
            ) from err

        ctx = init_empty_weights if is_accelerate_available() else nullcontext
        with ctx():
            text_model = CLIPTextModel(config)
    else:
        text_model = text_encoder

    keys = list(checkpoint.keys())

    text_model_dict = {}

    remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]

    for key in keys:
        for prefix in remove_prefixes:
            if key.startswith(prefix):
                text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

    if is_accelerate_available():
        for param_name, param in text_model_dict.items():
            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
    else:
        # `hasattr` needs the attribute name as its second argument; the previous one-argument call raised
        # TypeError as soon as `text_model` exposed an `embeddings` attribute.
        has_position_ids = hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")
        if not has_position_ids:
            text_model_dict.pop("text_model.embeddings.position_ids", None)

        text_model.load_state_dict(text_model_dict)

    return text_model
def convert_open_clip_checkpoint(
    checkpoint,
    config_name,
    prefix="cond_stage_model.model.",
    has_projection=False,
    local_files_only=False,
    **config_kwargs,
):
    """
    Load OpenCLIP text-encoder weights from an original checkpoint into a transformers CLIP text model.

    Args:
        checkpoint: Mapping from original parameter names to tensors.
        config_name: Name or path passed to `CLIPTextConfig.from_pretrained`.
        prefix: Key prefix under which the text encoder is stored in `checkpoint`.
        has_projection: When true a `CLIPTextModelWithProjection` is created, otherwise a `CLIPTextModel`.
        local_files_only: Forwarded to `CLIPTextConfig.from_pretrained`.
        **config_kwargs: Extra keyword arguments forwarded to `CLIPTextConfig.from_pretrained`.

    Returns:
        The text model with the converted weights loaded.

    Raises:
        ValueError: If the text configuration cannot be loaded.
    """
    # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
    # text_model = CLIPTextModelWithProjection.from_pretrained(
    #     "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
    # )
    try:
        config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
    except Exception as err:
        # Keep the original failure attached so the real cause stays visible in the traceback.
        raise ValueError(
            f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'."
        ) from err

    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)

    keys = list(checkpoint.keys())

    keys_to_ignore = set()
    if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23:
        # make sure to remove all keys > 22
        keys_to_ignore.update(k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23"))
        keys_to_ignore.add("cond_stage_model.model.text_projection")

    text_model_dict = {}

    # The fused in_proj tensors are split into three chunks of `d_model` rows each.
    if prefix + "text_projection" in checkpoint:
        d_model = int(checkpoint[prefix + "text_projection"].shape[0])
    else:
        d_model = 1024

    text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")

    def _to_hf_name(name):
        # Rewrite every original fragment listed in `protected` to its transformers counterpart.
        return textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], name)

    for key in keys:
        if key in keys_to_ignore:
            continue

        short_key = key[len(prefix) :]
        if short_key in textenc_conversion_map:
            if key.endswith("text_projection"):
                value = checkpoint[key].T.contiguous()
            else:
                value = checkpoint[key]

            text_model_dict[textenc_conversion_map[short_key]] = value

        if key.startswith(prefix + "transformer."):
            new_key = key[len(prefix + "transformer.") :]
            if new_key.endswith(".in_proj_weight"):
                new_key = _to_hf_name(new_key[: -len(".in_proj_weight")])
                text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
                text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
                text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
            elif new_key.endswith(".in_proj_bias"):
                new_key = _to_hf_name(new_key[: -len(".in_proj_bias")])
                text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
                text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
                text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
            else:
                text_model_dict[_to_hf_name(new_key)] = checkpoint[key]

    if is_accelerate_available():
        for param_name, param in text_model_dict.items():
            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
    else:
        # `hasattr` needs the attribute name as its second argument; the previous one-argument call raised
        # TypeError as soon as `text_model` exposed an `embeddings` attribute.
        has_position_ids = hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")
        if not has_position_ids:
            text_model_dict.pop("text_model.embeddings.position_ids", None)

        text_model.load_state_dict(text_model_dict)

    return text_model
def stable_unclip_image_noising_components(
    original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
):
    """
    Returns the noising components for the img2img and txt2img unclip pipelines.

    Converts the stability noise augmentor into

    1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
    2. a `DDPMScheduler` for holding the noise schedule

    If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.

    Args:
        original_config: Parsed original model configuration; `["model"]["params"]["noise_aug_config"]` is read.
        clip_stats_path: File holding a `(mean, std)` pair, read with `torch.load`.
        device: Passed to `torch.load` as `map_location`.

    Returns:
        A tuple `(image_normalizer, image_noising_scheduler)`.

    Raises:
        ValueError: If the config contains `clip_stats_path` but the `clip_stats_path` argument is `None`.
        NotImplementedError: If the noise augmentor class is not `CLIPEmbeddingNoiseAugmentation`.
    """
    noise_aug_config = original_config["model"]["params"]["noise_aug_config"]
    noise_aug_class = noise_aug_config["target"].split(".")[-1]

    if noise_aug_class != "CLIPEmbeddingNoiseAugmentation":
        raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")

    # Use item access throughout: the config is already indexed with `[...]` above, and a config parsed into
    # plain dictionaries does not support attribute access.
    noise_aug_config = noise_aug_config["params"]
    embedding_dim = noise_aug_config["timestep_dim"]
    noise_schedule_config = noise_aug_config["noise_schedule_config"]
    max_noise_level = noise_schedule_config["timesteps"]
    beta_schedule = noise_schedule_config["beta_schedule"]

    image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
    image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)

    if "clip_stats_path" in noise_aug_config:
        if clip_stats_path is None:
            raise ValueError("This stable unclip config requires a `clip_stats_path`")

        # NOTE(review): `torch.load` unpickles the file; only point this at trusted stats files.
        clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
        # Add a leading axis of size one to both statistics before loading them into the normalizer.
        clip_mean = clip_mean[None, :]
        clip_std = clip_std[None, :]

        clip_stats_state_dict = {
            "mean": clip_mean,
            "std": clip_std,
        }

        image_normalizer.load_state_dict(clip_stats_state_dict)

    return image_normalizer, image_noising_scheduler
https://huggingface.co/thibaud/controlnet-sd21/ if "time_embed.0.weight" in checkpoint: skip_extract_state_dict = True else: skip_extract_state_dict = False converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True, skip_extract_state_dict=skip_extract_state_dict, ) if is_accelerate_available(): for param_name, param in converted_ctrl_checkpoint.items(): set_module_tensor_to_device(controlnet, param_name, "cpu", value=param) else: controlnet.load_state_dict(converted_ctrl_checkpoint) return controlnet def download_from_original_stable_diffusion_ckpt( checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]], original_config_file: str = None, image_size: Optional[int] = None, prediction_type: str = None, model_type: str = None, extract_ema: bool = False, scheduler_type: str = "pndm", num_in_channels: Optional[int] = None, upcast_attention: Optional[bool] = None, device: str = None, from_safetensors: bool = False, stable_unclip: Optional[str] = None, stable_unclip_prior: Optional[str] = None, clip_stats_path: Optional[str] = None, controlnet: Optional[bool] = None, adapter: Optional[bool] = None, load_safety_checker: bool = True, safety_checker: Optional[StableDiffusionSafetyChecker] = None, feature_extractor: Optional[AutoFeatureExtractor] = None, pipeline_class: DiffusionPipeline = None, local_files_only=False, vae_path=None, vae=None, text_encoder=None, text_encoder_2=None, tokenizer=None, tokenizer_2=None, config_files=None, ) -> DiffusionPipeline: """ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file. Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the global step count, which will likely fail for models that have undergone further fine-tuning. 
Therefore, it is recommended that you override the default values and/or supply an `original_config_file` wherever possible. Args: checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict. original_config_file (`str`): Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models. image_size (`int`, *optional*, defaults to 512): The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 Base. Use 768 for Stable Diffusion v2. prediction_type (`str`, *optional*): The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. num_in_channels (`int`, *optional*, defaults to None): The number of input channels. If `None`, it will be automatically inferred. scheduler_type (`str`, *optional*, defaults to 'pndm'): Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", "ddim"]`. model_type (`str`, *optional*, defaults to `None`): The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. is_img2img (`bool`, *optional*, defaults to `False`): Whether the model should be loaded as an img2img pipeline. extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning. upcast_attention (`bool`, *optional*, defaults to `None`): Whether the attention computation should always be upcasted. This is necessary when running stable diffusion 2.1. device (`str`, *optional*, defaults to `None`): The device to use. 
Pass `None` to determine automatically. from_safetensors (`str`, *optional*, defaults to `False`): If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. load_safety_checker (`bool`, *optional*, defaults to `True`): Whether to load the safety checker or not. Defaults to `True`. safety_checker (`StableDiffusionSafetyChecker`, *optional*, defaults to `None`): Safety checker to use. If this parameter is `None`, the function will load a new instance of [StableDiffusionSafetyChecker] by itself, if needed. feature_extractor (`AutoFeatureExtractor`, *optional*, defaults to `None`): Feature extractor to use. If this parameter is `None`, the function will load a new instance of [AutoFeatureExtractor] by itself, if needed. pipeline_class (`str`, *optional*, defaults to `None`): The pipeline class to use. Pass `None` to determine automatically. local_files_only (`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). vae (`AutoencoderKL`, *optional*, defaults to `None`): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): An instance of [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if needed. 
config_files (`Dict[str, str]`, *optional*, defaults to `None`): A dictionary mapping from config file names to their contents. If this parameter is `None`, the function will load the config files by itself, if needed. Valid keys are: - `v1`: Config file for Stable Diffusion v1 - `v2`: Config file for Stable Diffusion v2 - `xl`: Config file for Stable Diffusion XL - `xl_refiner`: Config file for Stable Diffusion XL Refiner return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. """ # import pipelines here to avoid circular import error when using from_single_file method from diffusers import ( LDMTextToImagePipeline, PaintByExamplePipeline, StableDiffusionControlNetPipeline, StableDiffusionInpaintPipeline, StableDiffusionPipeline, StableDiffusionUpscalePipeline, StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, ) if prediction_type == "v-prediction": prediction_type = "v_prediction" if isinstance(checkpoint_path_or_dict, str): if from_safetensors: from safetensors.torch import load_file as safe_load checkpoint = safe_load(checkpoint_path_or_dict, device="cpu") else: if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) else: checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) elif isinstance(checkpoint_path_or_dict, dict): checkpoint = checkpoint_path_or_dict # Sometimes models don't have the global_step item if "global_step" in checkpoint: global_step = checkpoint["global_step"] else: logger.debug("global_step key not found in model") global_step = None # NOTE: this while loop isn't great but this controlnet checkpoint has one additional # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 while "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] 
if original_config_file is None: key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" is_upscale = pipeline_class == StableDiffusionUpscalePipeline config_url = None # model_type = "v1" if config_files is not None and "v1" in config_files: original_config_file = config_files["v1"] else: config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: # model_type = "v2" if config_files is not None and "v2" in config_files: original_config_file = config_files["v2"] else: config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" if global_step == 110000: # v2.1 needs to upcast attention upcast_attention = True elif key_name_sd_xl_base in checkpoint: # only base xl has two text embedders if config_files is not None and "xl" in config_files: original_config_file = config_files["xl"] else: config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" elif key_name_sd_xl_refiner in checkpoint: # only refiner xl has embedder and one text embedders if config_files is not None and "xl_refiner" in config_files: original_config_file = config_files["xl_refiner"] else: config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" if is_upscale: config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml" if config_url is not None: original_config_file = BytesIO(requests.get(config_url).content) else: with open(original_config_file, "r") as f: original_config_file = 
f.read() else: with open(original_config_file, "r") as f: original_config_file = f.read() original_config = yaml.safe_load(original_config_file) # Convert the text model. if ( model_type is None and "cond_stage_config" in original_config["model"]["params"] and original_config["model"]["params"]["cond_stage_config"] is not None ): model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") elif model_type is None and original_config["model"]["params"]["network_config"] is not None: if original_config["model"]["params"]["network_config"]["params"]["context_dim"] == 2048: model_type = "SDXL" else: model_type = "SDXL-Refiner" if image_size is None: image_size = 1024 if pipeline_class is None: # Check if we have a SDXL or SD model and initialize default pipeline if model_type not in ["SDXL", "SDXL-Refiner"]: pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline else: pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline if num_in_channels is None and pipeline_class in [ StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLControlNetInpaintPipeline, ]: num_in_channels = 9 if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline: num_in_channels = 7 elif num_in_channels is None: num_in_channels = 4 if "unet_config" in original_config["model"]["params"]: original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels elif "network_config" in original_config["model"]["params"]: original_config["model"]["params"]["network_config"]["params"]["in_channels"] = num_in_channels if ( "parameterization" in original_config["model"]["params"] and original_config["model"]["params"]["parameterization"] == "v" ): if prediction_type is None: # NOTE: For stable diffusion 2 base it is recommended to pass 
`prediction_type=="epsilon"` # as it relies on a brittle global step parameter here prediction_type = "epsilon" if global_step == 875000 else "v_prediction" if image_size is None: # NOTE: For stable diffusion 2 base one has to pass `image_size==512` # as it relies on a brittle global step parameter here image_size = 512 if global_step == 875000 else 768 else: if prediction_type is None: prediction_type = "epsilon" if image_size is None: image_size = 512 if controlnet is None and "control_stage_config" in original_config["model"]["params"]: path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" controlnet = convert_controlnet_checkpoint( checkpoint, original_config, path, image_size, upcast_attention, extract_ema ) if "timesteps" in original_config["model"]["params"]: num_train_timesteps = original_config["model"]["params"]["timesteps"] else: num_train_timesteps = 1000 if model_type in ["SDXL", "SDXL-Refiner"]: scheduler_dict = { "beta_schedule": "scaled_linear", "beta_start": 0.00085, "beta_end": 0.012, "interpolation_type": "linear", "num_train_timesteps": num_train_timesteps, "prediction_type": "epsilon", "sample_max_value": 1.0, "set_alpha_to_one": False, "skip_prk_steps": True, "steps_offset": 1, "timestep_spacing": "leading", } scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) scheduler_type = "euler" else: if "linear_start" in original_config["model"]["params"]: beta_start = original_config["model"]["params"]["linear_start"] else: beta_start = 0.02 if "linear_end" in original_config["model"]["params"]: beta_end = original_config["model"]["params"]["linear_end"] else: beta_end = 0.085 scheduler = DDIMScheduler( beta_end=beta_end, beta_schedule="scaled_linear", beta_start=beta_start, num_train_timesteps=num_train_timesteps, steps_offset=1, clip_sample=False, set_alpha_to_one=False, prediction_type=prediction_type, ) # make sure scheduler works correctly with DDIM scheduler.register_to_config(clip_sample=False) if 
scheduler_type == "pndm": config = dict(scheduler.config) config["skip_prk_steps"] = True scheduler = PNDMScheduler.from_config(config) elif scheduler_type == "lms": scheduler = LMSDiscreteScheduler.from_config(scheduler.config) elif scheduler_type == "heun": scheduler = HeunDiscreteScheduler.from_config(scheduler.config) elif scheduler_type == "euler": scheduler = EulerDiscreteScheduler.from_config(scheduler.config) elif scheduler_type == "euler-ancestral": scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) elif scheduler_type == "dpm": scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) elif scheduler_type == "ddim": scheduler = scheduler else: raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") if pipeline_class == StableDiffusionUpscalePipeline: image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"] # Convert the UNet2DConditionModel model. unet_config = create_unet_diffusers_config(original_config, image_size=image_size) unet_config["upcast_attention"] = upcast_attention path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" converted_unet_checkpoint = convert_ldm_unet_checkpoint( checkpoint, unet_config, path=path, extract_ema=extract_ema ) ctx = init_empty_weights if is_accelerate_available() else nullcontext with ctx(): unet = UNet2DConditionModel(**unet_config) if is_accelerate_available(): if model_type not in ["SDXL", "SDXL-Refiner"]: # SBM Delay this. for param_name, param in converted_unet_checkpoint.items(): set_module_tensor_to_device(unet, param_name, "cpu", value=param) else: unet.load_state_dict(converted_unet_checkpoint) # Convert the VAE model. 
if vae_path is None and vae is None: vae_config = create_vae_diffusers_config(original_config, image_size=image_size) converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) if ( "model" in original_config and "params" in original_config["model"] and "scale_factor" in original_config["model"]["params"] ): vae_scaling_factor = original_config["model"]["params"]["scale_factor"] else: vae_scaling_factor = 0.18215 # default SD scaling factor vae_config["scaling_factor"] = vae_scaling_factor ctx = init_empty_weights if is_accelerate_available() else nullcontext with ctx(): vae = AutoencoderKL(**vae_config) if is_accelerate_available(): for param_name, param in converted_vae_checkpoint.items(): set_module_tensor_to_device(vae, param_name, "cpu", value=param) else: vae.load_state_dict(converted_vae_checkpoint) elif vae is None: vae = AutoencoderKL.from_pretrained(vae_path, local_files_only=local_files_only) if model_type == "FrozenOpenCLIPEmbedder": config_name = "stabilityai/stable-diffusion-2" config_kwargs = {"subfolder": "text_encoder"} if text_encoder is None: text_model = convert_open_clip_checkpoint( checkpoint, config_name, local_files_only=local_files_only, **config_kwargs ) else: text_model = text_encoder try: tokenizer = CLIPTokenizer.from_pretrained( "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only ) except Exception: raise ValueError( f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'stabilityai/stable-diffusion-2'." 
) if stable_unclip is None: if controlnet: pipe = pipeline_class( vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, controlnet=controlnet, safety_checker=safety_checker, feature_extractor=feature_extractor, ) if hasattr(pipe, "requires_safety_checker"): pipe.requires_safety_checker = False elif pipeline_class == StableDiffusionUpscalePipeline: scheduler = DDIMScheduler.from_pretrained( "stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler" ) low_res_scheduler = DDPMScheduler.from_pretrained( "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler" ) pipe = pipeline_class( vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, low_res_scheduler=low_res_scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) else: pipe = pipeline_class( vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) if hasattr(pipe, "requires_safety_checker"): pipe.requires_safety_checker = False else: image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( original_config, clip_stats_path=clip_stats_path, device=device ) if stable_unclip == "img2img": feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) pipe = StableUnCLIPImg2ImgPipeline( # image encoding components feature_extractor=feature_extractor, image_encoder=image_encoder, # image noising components image_normalizer=image_normalizer, image_noising_scheduler=image_noising_scheduler, # regular denoising components tokenizer=tokenizer, text_encoder=text_model, unet=unet, scheduler=scheduler, # vae vae=vae, ) elif stable_unclip == "txt2img": if stable_unclip_prior is None or stable_unclip_prior == "karlo": karlo_model = "kakaobrain/karlo-v1-alpha" prior = PriorTransformer.from_pretrained( karlo_model, subfolder="prior", local_files_only=local_files_only ) try: 
prior_tokenizer = CLIPTokenizer.from_pretrained( "openai/clip-vit-large-patch14", local_files_only=local_files_only ) except Exception: raise ValueError( f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." ) prior_text_model = CLIPTextModelWithProjection.from_pretrained( "openai/clip-vit-large-patch14", local_files_only=local_files_only ) prior_scheduler = UnCLIPScheduler.from_pretrained( karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only ) prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) else: raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") pipe = StableUnCLIPPipeline( # prior components prior_tokenizer=prior_tokenizer, prior_text_encoder=prior_text_model, prior=prior, prior_scheduler=prior_scheduler, # image noising components image_normalizer=image_normalizer, image_noising_scheduler=image_noising_scheduler, # regular denoising components tokenizer=tokenizer, text_encoder=text_model, unet=unet, scheduler=scheduler, # vae vae=vae, ) else: raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") elif model_type == "PaintByExample": vision_model = convert_paint_by_example_checkpoint(checkpoint) try: tokenizer = CLIPTokenizer.from_pretrained( "openai/clip-vit-large-patch14", local_files_only=local_files_only ) except Exception: raise ValueError( f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." ) try: feature_extractor = AutoFeatureExtractor.from_pretrained( "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only ) except Exception: raise ValueError( f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'." 
) pipe = PaintByExamplePipeline( vae=vae, image_encoder=vision_model, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor=feature_extractor, ) elif model_type == "FrozenCLIPEmbedder": text_model = convert_ldm_clip_checkpoint( checkpoint, local_files_only=local_files_only, text_encoder=text_encoder ) try: tokenizer = ( CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) if tokenizer is None else tokenizer ) except Exception: raise ValueError( f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." ) if load_safety_checker: safety_checker = StableDiffusionSafetyChecker.from_pretrained( "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only ) feature_extractor = AutoFeatureExtractor.from_pretrained( "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only ) if controlnet: pipe = pipeline_class( vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) else: pipe = pipeline_class( vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) elif model_type in ["SDXL", "SDXL-Refiner"]: is_refiner = model_type == "SDXL-Refiner" if (is_refiner is False) and (tokenizer is None): try: tokenizer = CLIPTokenizer.from_pretrained( "openai/clip-vit-large-patch14", local_files_only=local_files_only ) except Exception: raise ValueError( f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." 
) if (is_refiner is False) and (text_encoder is None): text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) if tokenizer_2 is None: try: tokenizer_2 = CLIPTokenizer.from_pretrained( "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only ) except Exception: raise ValueError( f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'." ) if text_encoder_2 is None: config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" config_kwargs = {"projection_dim": 1280} prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model." text_encoder_2 = convert_open_clip_checkpoint( checkpoint, config_name, prefix=prefix, has_projection=True, local_files_only=local_files_only, **config_kwargs, ) if is_accelerate_available(): # SBM Now move model to cpu. for param_name, param in converted_unet_checkpoint.items(): set_module_tensor_to_device(unet, param_name, "cpu", value=param) if controlnet: pipe = pipeline_class( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, unet=unet, controlnet=controlnet, scheduler=scheduler, force_zeros_for_empty_prompt=True, ) elif adapter: pipe = pipeline_class( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, unet=unet, adapter=adapter, scheduler=scheduler, force_zeros_for_empty_prompt=True, ) else: pipeline_kwargs = { "vae": vae, "text_encoder": text_encoder, "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, "unet": unet, "scheduler": scheduler, } if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or ( pipeline_class == StableDiffusionXLInpaintPipeline ): pipeline_kwargs.update({"requires_aesthetics_score": is_refiner}) if is_refiner: 
def download_controlnet_from_original_ckpt(
    checkpoint_path: str,
    original_config_file: str,
    image_size: int = 512,
    extract_ema: bool = False,
    num_in_channels: Optional[int] = None,
    upcast_attention: Optional[bool] = None,
    device: Optional[str] = None,
    from_safetensors: bool = False,
    use_linear_projection: Optional[bool] = None,
    cross_attention_dim: Optional[int] = None,
) -> DiffusionPipeline:
    """
    Load an original ControlNet checkpoint and convert it with `convert_controlnet_checkpoint`.

    Args:
        checkpoint_path (`str`):
            Path to the checkpoint file. Also forwarded to `convert_controlnet_checkpoint`.
        original_config_file (`str`):
            Path to the YAML configuration file of the original model.
        image_size (`int`, defaults to 512):
            Forwarded to `convert_controlnet_checkpoint`.
        extract_ema (`bool`, defaults to `False`):
            Forwarded to `convert_controlnet_checkpoint`.
        num_in_channels (`int`, *optional*):
            When given, overrides `model.params.unet_config.params.in_channels` in the loaded config.
        upcast_attention (`bool`, *optional*):
            Forwarded to `convert_controlnet_checkpoint`.
        device (`str`, *optional*):
            `map_location` used by `torch.load`. Defaults to `"cuda"` when CUDA is available, else `"cpu"`.
            Ignored when `from_safetensors` is `True` (tensors are then always read on `"cpu"`).
        from_safetensors (`bool`, defaults to `False`):
            Read the checkpoint with `safetensors` instead of `torch.load`.
        use_linear_projection (`bool`, *optional*):
            Forwarded to `convert_controlnet_checkpoint`.
        cross_attention_dim (`int`, *optional*):
            Forwarded to `convert_controlnet_checkpoint`.

    Returns:
        The object returned by `convert_controlnet_checkpoint`.
        NOTE(review): the `DiffusionPipeline` return annotation is kept for backward compatibility, but the
        value returned here is the converted controlnet, not a pipeline - confirm and tighten the annotation.

    Raises:
        ValueError: If the config has no `control_stage_config` entry under `model.params`.
    """
    if from_safetensors:
        from safetensors import safe_open

        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            checkpoint = {key: f.get_tensor(key) for key in f.keys()}
    else:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # SECURITY NOTE: `torch.load` unpickles the file, which can execute arbitrary code. Only use this
        # branch with checkpoints from a trusted source; prefer `from_safetensors=True` otherwise.
        checkpoint = torch.load(checkpoint_path, map_location=device)

    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
    # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    # Parse the YAML straight from the file handle instead of rebinding the path argument to its contents.
    with open(original_config_file, "r") as f:
        original_config = yaml.safe_load(f)

    if num_in_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

    if "control_stage_config" not in original_config["model"]["params"]:
        raise ValueError("`control_stage_config` not present in original config")

    controlnet = convert_controlnet_checkpoint(
        checkpoint,
        original_config,
        checkpoint_path,
        image_size,
        upcast_attention,
        extract_ema,
        use_linear_projection=use_linear_projection,
        cross_attention_dim=cross_attention_dim,
    )

    return controlnet
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
    r"""
    Flax-based pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.FlaxCLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`FlaxUNet2DConditionModel`]):
            A `FlaxUNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
            [`FlaxDPMSolverMultistepScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            Stored on the pipeline as `self.dtype`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        self.dtype = dtype

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # Old unet configs (saved before 0.9.0) that declare `sample_size` < 64 are patched to 64 after
        # emitting a deprecation message.
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Ratio between pixel resolution and latent resolution, derived from the number of VAE blocks.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_inputs(self, prompt: Union[str, List[str]]):
        """
        Tokenize `prompt` and return the token ids as a NumPy array.

        Prompts are padded/truncated to `self.tokenizer.model_max_length`.

        Raises:
            ValueError: If `prompt` is neither a `str` nor a `list`.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        return text_input.input_ids

    def _get_has_nsfw_concepts(self, features, params):
        """Run the safety checker on already extracted image `features` and return its output."""
        has_nsfw_concepts = self.safety_checker(features, params)
        return has_nsfw_concepts

    def _run_safety_checker(self, images, safety_model_params, jit=False):
        """
        Run the safety checker on `images` and black out every flagged image.

        Args:
            images: Indexable batch of `uint8` images; each entry is passed to `PIL.Image.fromarray`.
            safety_model_params: Safety checker parameters. They should already be replicated when `jit` is
                `True`.
            jit (`bool`, defaults to `False`): Use the pmapped safety checker on sharded features.

        Returns:
            `(images, has_nsfw_concepts)`. `images` is a copy of the input when at least one image was blacked
            out, otherwise the input object itself.
        """
        pil_images = [Image.fromarray(image) for image in images]
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        if jit:
            features = shard(features)
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            has_nsfw_concepts = unshard(has_nsfw_concepts)
        else:
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        images_was_copied = False
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                # Copy lazily so the caller's array is left untouched.
                if not images_was_copied:
                    images_was_copied = True
                    images = images.copy()

                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

        # Warn once per call, not once per image.
        if any(has_nsfw_concepts):
            warnings.warn(
                "Potential NSFW content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts

    def _generate(
        self,
        prompt_ids: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.Array,
        num_inference_steps: int,
        height: int,
        width: int,
        guidance_scale: float,
        latents: Optional[jnp.ndarray] = None,
        neg_prompt_ids: Optional[jnp.ndarray] = None,
    ):
        """
        Run the denoising loop with classifier-free guidance and decode the latents with the VAE.

        Returns:
            The decoded images, rescaled with `image / 2 + 0.5`, clipped to `[0, 1]` and transposed with
            `(0, 2, 3, 1)` so that the channel axis comes last.

        Raises:
            ValueError: If `height` or `width` is not divisible by 8, or if the supplied `latents` do not have
                the expected shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]

        max_length = prompt_ids.shape[-1]

        if neg_prompt_ids is None:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            ).input_ids
        else:
            uncond_input = neg_prompt_ids
        negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
        # Unconditional embeddings first: `loop_body` splits the prediction in the same order.
        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

        # Ensure model output will be `float32` before going into the scheduler
        guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)

        latents_shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        def loop_body(step, args):
            latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
            ).sample
            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, scheduler_state

        scheduler_state = self.scheduler.set_timesteps(
            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
        )

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * params["scheduler"].init_noise_sigma

        if DEBUG:
            # run with python for loop
            for i in range(num_inference_steps):
                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
        else:
            latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt_ids: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.Array,
        num_inference_steps: int = 50,
        height: Optional[int] = None,
        width: Optional[int] = None,
        guidance_scale: Union[float, jnp.ndarray] = 7.5,
        latents: Optional[jnp.ndarray] = None,
        neg_prompt_ids: Optional[jnp.ndarray] = None,
        return_dict: bool = True,
        jit: bool = False,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt_ids (`jnp.ndarray`):
                Token ids of the prompt(s), e.g. as returned by `prepare_inputs`. May carry an extra leading
                device axis when the inputs are sharded (`jit=True`).
            params (`Dict` or `FrozenDict`):
                Model parameters. The entries `"text_encoder"`, `"unet"`, `"vae"` and `"scheduler"` are read
                during generation, and `"safety_checker"` when a safety checker is set.
            prng_seed (`jax.Array`):
                Random key used to sample the initial latents when `latents` is not provided.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            guidance_scale (`float` or `jnp.ndarray`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                prompt at the expense of lower image quality. Python scalars (`int` or `float`) are expanded to
                one value per entry of `prompt_ids`; arrays are passed through unchanged.
            latents (`jnp.ndarray`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents array is generated by sampling with `prng_seed`.
            neg_prompt_ids (`jnp.ndarray`, *optional*):
                Token ids of the negative prompt(s). When not provided, empty prompts are tokenized and used as
                the unconditional input.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions.

                    <Tip warning={true}>

                    This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
                    future release.

                    </Tip>

            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead
                of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element holds the generated images and
                the second element indicates whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content (`False` when no safety checker is set).
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # Accept any Python scalar (e.g. `guidance_scale=7`), not only `float`: a bare scalar cannot be
        # mapped over its first axis by the pmapped `_p_generate`.
        if isinstance(guidance_scale, (int, float)):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            guidance_scale = jnp.array([float(guidance_scale)] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                guidance_scale = guidance_scale[:, None]

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )
        else:
            images = self._generate(
                prompt_ids,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )

        if self.safety_checker is not None:
            safety_params = params["safety_checker"]
            images_uint8_casted = (images * 255).round().astype("uint8")
            # Flatten every leading (device / batch) axis so the same code handles both the sharded
            # output of `_p_generate` and the un-sharded output of `_generate`.
            images_uint8_casted = np.asarray(images_uint8_casted).reshape(-1, height, width, 3)

            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
            images = np.asarray(images).copy()
            output_shape = images.shape
            images = images.reshape(-1, height, width, 3)

            # block images: `has_nsfw_concept` is indexed like the flattened image batch
            for i, is_nsfw in enumerate(has_nsfw_concept):
                if is_nsfw:
                    images[i] = np.asarray(images_uint8_casted[i])

            images = images.reshape(output_shape)
        else:
            images = np.asarray(images)
            has_nsfw_concept = False

        if not return_dict:
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)


# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial(
    jax.pmap,
    in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0),
    static_broadcasted_argnums=(0, 4, 5, 6),
)
def _p_generate(
    pipe,
    prompt_ids,
    params,
    prng_seed,
    num_inference_steps,
    height,
    width,
    guidance_scale,
    latents,
    neg_prompt_ids,
):
    """pmapped wrapper around `pipe._generate`; arguments are forwarded unchanged."""
    return pipe._generate(
        prompt_ids,
        params,
        prng_seed,
        num_inference_steps,
        height,
        width,
        guidance_scale,
        latents,
        neg_prompt_ids,
    )


@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_has_nsfw_concepts(pipe, features, params):
    """pmapped wrapper around `pipe._get_has_nsfw_concepts`; `pipe` is a static argument."""
    return pipe._get_has_nsfw_concepts(features, params)


def unshard(x: jnp.ndarray):
    """Merge the two leading axes of `x` into one, keeping all remaining axes."""
    # einops.rearrange(x, 'd b ... -> (d b) ...')
    num_devices, batch_size = x.shape[:2]
    rest = x.shape[2:]
    return x.reshape(num_devices * batch_size, *rest)
import warnings from functools import partial from typing import Dict, List, Optional, Union import jax import jax.numpy as jnp import numpy as np from flax.core.frozen_dict import FrozenDict from flax.jax_utils import unreplicate from flax.training.common_utils import shard from PIL import Image from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...schedulers import ( FlaxDDIMScheduler, FlaxDPMSolverMultistepScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring from ..pipeline_flax_utils import FlaxDiffusionPipeline from .pipeline_output import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Set to True to use python for loop instead of jax.fori_loop for easier debugging DEBUG = False EXAMPLE_DOC_STRING = """ Examples: ```py >>> import jax >>> import numpy as np >>> import jax.numpy as jnp >>> from flax.jax_utils import replicate >>> from flax.training.common_utils import shard >>> import requests >>> from io import BytesIO >>> from PIL import Image >>> from diffusers import FlaxStableDiffusionImg2ImgPipeline >>> def create_key(seed=0): ... return jax.random.PRNGKey(seed) >>> rng = create_key(0) >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" >>> response = requests.get(url) >>> init_img = Image.open(BytesIO(response.content)).convert("RGB") >>> init_img = init_img.resize((768, 512)) >>> prompts = "A fantasy landscape, trending on artstation" >>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained( ... "CompVis/stable-diffusion-v1-4", ... revision="flax", ... dtype=jnp.bfloat16, ... 
) >>> num_samples = jax.device_count() >>> rng = jax.random.split(rng, jax.device_count()) >>> prompt_ids, processed_image = pipeline.prepare_inputs( ... prompt=[prompts] * num_samples, image=[init_img] * num_samples ... ) >>> p_params = replicate(params) >>> prompt_ids = shard(prompt_ids) >>> processed_image = shard(processed_image) >>> output = pipeline( ... prompt_ids=prompt_ids, ... image=processed_image, ... params=p_params, ... prng_seed=rng, ... strength=0.75, ... num_inference_steps=50, ... jit=True, ... height=512, ... width=768, ... ).images >>> output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) ``` """ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): r""" Flax-based pipeline for text-guided image-to-image generation using Stable Diffusion. This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: vae ([`FlaxAutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.FlaxCLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`FlaxUNet2DConditionModel`]): A `FlaxUNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or [`FlaxDPMSolverMultistepScheduler`]. safety_checker ([`FlaxStableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. 
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ def __init__( self, vae: FlaxAutoencoderKL, text_encoder: FlaxCLIPTextModel, tokenizer: CLIPTokenizer, unet: FlaxUNet2DConditionModel, scheduler: Union[ FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler ], safety_checker: FlaxStableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, dtype: jnp.dtype = jnp.float32, ): super().__init__() self.dtype = dtype if safety_checker is None: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]):
    """Tokenize the prompt(s) and turn the PIL image(s) into a single array.

    Args:
        prompt: A prompt string or a list of prompt strings.
        image: A `PIL.Image.Image` or a list of them.

    Returns:
        A tuple `(input_ids, processed_images)`: the tokenizer's `input_ids` and the concatenation of
        `preprocess(img, jnp.float32)` over all images.

    Raises:
        ValueError: If `prompt` is neither `str` nor `list`, or `image` is neither a PIL image nor a `list`.
    """
    if not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # Normalize the image argument to a list, rejecting anything that is neither a PIL image nor a list.
    if isinstance(image, Image.Image):
        image_list = [image]
    elif isinstance(image, list):
        image_list = image
    else:
        raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

    per_image_arrays = [preprocess(single_image, jnp.float32) for single_image in image_list]
    processed_images = jnp.concatenate(per_image_arrays)

    # Pad/truncate every prompt to the tokenizer's maximum length so the ids form a rectangular array.
    tokenizer_options = {
        "padding": "max_length",
        "max_length": self.tokenizer.model_max_length,
        "truncation": True,
        "return_tensors": "np",
    }
    tokenized = self.tokenizer(prompt, **tokenizer_options)
    return tokenized.input_ids, processed_images
def get_timestep_start(self, num_inference_steps, strength):
    """Return the index of the first denoising step to run for the given `strength`.

    `int(num_inference_steps * strength)` steps (capped at `num_inference_steps`) are kept;
    the returned value is how many leading steps are skipped, never below zero.
    """
    # Number of steps that will actually be run, capped at the total.
    steps_to_run = int(num_inference_steps * strength)
    if steps_to_run > num_inference_steps:
        steps_to_run = num_inference_steps

    # Everything before this index is skipped.
    first_step = num_inference_steps - steps_to_run
    return first_step if first_step >= 0 else 0
method=self.vae.encode).latent_dist init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2)) init_latents = self.vae.config.scaling_factor * init_latents def loop_body(step, args): latents, scheduler_state = args # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes latents_input = jnp.concatenate([latents] * 2) t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] timestep = jnp.broadcast_to(t, latents_input.shape[0]) latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) # predict the noise residual noise_pred = self.unet.apply( {"params": params["unet"]}, jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), encoder_hidden_states=context, ).sample # perform guidance noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents, scheduler_state scheduler_state = self.scheduler.set_timesteps( params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape ) latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size) latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep) # scale the initial noise by the standard deviation required by the scheduler latents = latents * params["scheduler"].init_noise_sigma if DEBUG: # run with python for loop for i in range(start_timestep, num_inference_steps): latents, scheduler_state = loop_body(i, (latents, scheduler_state)) else: latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state)) # scale and decode the image latents with 
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt_ids: jnp.ndarray,
    image: jnp.ndarray,
    params: Union[Dict, FrozenDict],
    prng_seed: jax.Array,
    strength: float = 0.8,
    num_inference_steps: int = 50,
    height: Optional[int] = None,
    width: Optional[int] = None,
    guidance_scale: Union[float, jnp.ndarray] = 7.5,
    noise: jnp.ndarray = None,
    neg_prompt_ids: jnp.ndarray = None,
    return_dict: bool = True,
    jit: bool = False,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt_ids (`jnp.ndarray`):
            The prompt or prompts to guide image generation.
        image (`jnp.ndarray`):
            Array representing an image batch to be used as the starting point.
        params (`Dict` or `FrozenDict`):
            Dictionary containing the model parameters/weights.
        prng_seed (`jax.Array`):
            Array containing random number generator key.
        strength (`float`, *optional*, defaults to 0.8):
            Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
            starting point and more noise is added the higher the `strength`. The number of denoising steps
            depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the
            denoising process runs for the full number of iterations specified in `num_inference_steps`. A value
            of 1 essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference. This parameter is modulated by `strength`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        noise (`jnp.ndarray`, *optional*):
            Pre-generated noise sampled from a Gaussian distribution, added to the encoded `image` to build the
            initial latents. Can be used to tweak the same generation with different prompts. If not provided,
            the noise is sampled using `prng_seed`.
        neg_prompt_ids (`jnp.ndarray`, *optional*):
            Token ids of the negative prompt. If not provided, an empty prompt is tokenized and used instead.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead
            of a plain tuple.
        jit (`bool`, defaults to `False`):
            Whether to run `pmap` versions of the generation and safety scoring functions.

                <Tip warning={true}>

                This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
                future release.

                </Tip>

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is
            returned, otherwise a `tuple` is returned where the first element is a list with the generated images
            and the second element is a list of `bool`s indicating whether the corresponding generated image
            contains "not-safe-for-work" (nsfw) content.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    if isinstance(guidance_scale, float):
        # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
        # shape information, as they may be sharded (when `jit` is `True`), or not.
        guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
        if len(prompt_ids.shape) > 2:
            # Assume sharded
            guidance_scale = guidance_scale[:, None]

    start_timestep = self.get_timestep_start(num_inference_steps, strength)

    if jit:
        images = _p_generate(
            self,
            prompt_ids,
            image,
            params,
            prng_seed,
            start_timestep,
            num_inference_steps,
            height,
            width,
            guidance_scale,
            noise,
            neg_prompt_ids,
        )
    else:
        images = self._generate(
            prompt_ids,
            image,
            params,
            prng_seed,
            start_timestep,
            num_inference_steps,
            height,
            width,
            guidance_scale,
            noise,
            neg_prompt_ids,
        )

    if self.safety_checker is not None:
        safety_params = params["safety_checker"]
        images_uint8_casted = (images * 255).round().astype("uint8")

        # Remember the leading axes ((num_devices, batch) when sharded, (batch,) otherwise) so the
        # original layout can be restored after working on a flat list of images.
        leading_shape = images.shape[:-3]

        images_uint8_casted = np.asarray(images_uint8_casted).reshape(-1, height, width, 3)
        images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)

        # `np.asarray` on a JAX array yields a read-only array, so take a writable copy, and flatten it
        # so that index `i` addresses the same image as `has_nsfw_concept[i]`.
        images = np.asarray(images).copy().reshape(-1, height, width, 3)

        # block images
        for i, is_nsfw in enumerate(has_nsfw_concept):
            if is_nsfw:
                images[i] = np.asarray(images_uint8_casted[i])

        images = images.reshape(*leading_shape, height, width, 3)
    else:
        images = np.asarray(images)
        has_nsfw_concept = False

    if not return_dict:
        return (images, has_nsfw_concept)

    return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)


# Static argnums are pipe, start_timestep, num_inference_steps, height, width. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
def _p_generate(
    pipe,
    prompt_ids,
    image,
    params,
    prng_seed,
    start_timestep,
    num_inference_steps,
    height,
    width,
    guidance_scale,
    noise,
    neg_prompt_ids,
):
    """Forward every argument, in order, to ``pipe._generate``; parallelised with ``jax.pmap`` below."""
    return pipe._generate(
        prompt_ids,
        image,
        params,
        prng_seed,
        start_timestep,
        num_inference_steps,
        height,
        width,
        guidance_scale,
        noise,
        neg_prompt_ids,
    )


# Arguments 0 and 5-8 (pipe, start_timestep, num_inference_steps, height, width) are static and broadcast;
# all remaining arguments are mapped over their leading axis.
_p_generate = jax.pmap(
    _p_generate,
    in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0),
    static_broadcasted_argnums=(0, 5, 6, 7, 8),
)


def _p_get_has_nsfw_concepts(pipe, features, params):
    """Forward to ``pipe._get_has_nsfw_concepts``; parallelised with ``jax.pmap`` below."""
    return pipe._get_has_nsfw_concepts(features, params)


# Only `pipe` (argument 0) is static; `features` and `params` are mapped over their leading axis.
_p_get_has_nsfw_concepts = jax.pmap(_p_get_has_nsfw_concepts, static_broadcasted_argnums=(0,))


def unshard(x: jnp.ndarray):
    """Merge the two leading axes of ``x`` into a single axis.

    Equivalent to ``einops.rearrange(x, 'd b ... -> (d b) ...')``.
    """
    num_devices, batch_size, *trailing = x.shape
    return x.reshape(num_devices * batch_size, *trailing)


def preprocess(image, dtype):
    """Convert a PIL image into a batched array scaled to ``[-1, 1]``.

    The image is resized so both sides are integer multiples of 32 (rounding down), divided by 255,
    given a leading axis of length 1, transposed with ``(0, 3, 1, 2)`` and mapped through ``2 * x - 1``.
    """
    width, height = image.size
    # resize to integer multiple of 32
    target_size = (width - width % 32, height - height % 32)
    resized = image.resize(target_size, resample=PIL_INTERPOLATION["lanczos"])
    pixels = jnp.array(resized).astype(dtype) / 255.0
    batched = pixels[None].transpose(0, 3, 1, 2)
    return 2.0 * batched - 1.0
import warnings from functools import partial from typing import Dict, List, Optional, Union import jax import jax.numpy as jnp import numpy as np from flax.core.frozen_dict import FrozenDict from flax.jax_utils import unreplicate from flax.training.common_utils import shard from packaging import version from PIL import Image from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...schedulers import ( FlaxDDIMScheduler, FlaxDPMSolverMultistepScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring from ..pipeline_flax_utils import FlaxDiffusionPipeline from .pipeline_output import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Set to True to use python for loop instead of jax.fori_loop for easier debugging DEBUG = False EXAMPLE_DOC_STRING = """ Examples: ```py >>> import jax >>> import numpy as np >>> from flax.jax_utils import replicate >>> from flax.training.common_utils import shard >>> import PIL >>> import requests >>> from io import BytesIO >>> from diffusers import FlaxStableDiffusionInpaintPipeline >>> def download_image(url): ... response = requests.get(url) ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" >>> init_image = download_image(img_url).resize((512, 512)) >>> mask_image = download_image(mask_url).resize((512, 512)) >>> pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained( ... "xvjiarui/stable-diffusion-2-inpainting" ... 
) >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" >>> prng_seed = jax.random.PRNGKey(0) >>> num_inference_steps = 50 >>> num_samples = jax.device_count() >>> prompt = num_samples * [prompt] >>> init_image = num_samples * [init_image] >>> mask_image = num_samples * [mask_image] >>> prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs( ... prompt, init_image, mask_image ... ) # shard inputs and rng >>> params = replicate(params) >>> prng_seed = jax.random.split(prng_seed, jax.device_count()) >>> prompt_ids = shard(prompt_ids) >>> processed_masked_images = shard(processed_masked_images) >>> processed_masks = shard(processed_masks) >>> images = pipeline( ... prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True ... ).images >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) ``` """ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): r""" Flax-based pipeline for text-guided image inpainting using Stable Diffusion. 🧪 This is an experimental feature! This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: vae ([`FlaxAutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.FlaxCLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`FlaxUNet2DConditionModel`]): A `FlaxUNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or [`FlaxDPMSolverMultistepScheduler`]. safety_checker ([`FlaxStableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ def __init__( self, vae: FlaxAutoencoderKL, text_encoder: FlaxCLIPTextModel, tokenizer: CLIPTokenizer, unet: FlaxUNet2DConditionModel, scheduler: Union[ FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler ], safety_checker: FlaxStableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, dtype: jnp.dtype = jnp.float32, ): super().__init__() self.dtype = dtype if safety_checker is None: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
def prepare_inputs(
    self,
    prompt: Union[str, List[str]],
    image: Union[Image.Image, List[Image.Image]],
    mask: Union[Image.Image, List[Image.Image]],
):
    """Tokenize the prompt(s) and build the masked-image and mask arrays for inpainting.

    Args:
        prompt: A prompt string or a list of prompt strings.
        image: A `PIL.Image.Image` or a list of them.
        mask: A `PIL.Image.Image` or a list of them.

    Returns:
        A tuple `(input_ids, processed_masked_images, processed_masks)`. The masks are binarized at 0.5
        (values below become 0, values at or above become 1) and the masked images are the processed images
        multiplied by `processed_masks < 0.5`, i.e. zeroed wherever the binarized mask is 1.

    Raises:
        ValueError: If `prompt` is neither `str` nor `list`, or if `image` or `mask` is neither a PIL image
            nor a `list`.
    """
    if not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if not isinstance(image, (Image.Image, list)):
        raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

    if isinstance(image, Image.Image):
        image = [image]

    if not isinstance(mask, (Image.Image, list)):
        # Report the offending argument: this check is about `mask`, not `image`.
        raise ValueError(f"mask has to be of type `PIL.Image.Image` or list but is {type(mask)}")

    if isinstance(mask, Image.Image):
        mask = [mask]

    processed_images = jnp.concatenate([preprocess_image(img, jnp.float32) for img in image])
    processed_masks = jnp.concatenate([preprocess_mask(m, jnp.float32) for m in mask])
    # JAX arrays are immutable, so the in-place NumPy idiom is expressed with `.at[...].set(...)`.
    # processed_masks[processed_masks < 0.5] = 0
    processed_masks = processed_masks.at[processed_masks < 0.5].set(0)
    # processed_masks[processed_masks >= 0.5] = 1
    processed_masks = processed_masks.at[processed_masks >= 0.5].set(1)

    # Keep only the pixels outside the mask.
    processed_masked_images = processed_images * (processed_masks < 0.5)

    text_input = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=self.tokenizer.model_max_length,
        truncation=True,
        return_tensors="np",
    )
    return text_input.input_ids, processed_masked_images, processed_masks
more images. A black image will be returned" " instead. Try again with a different prompt and/or seed." ) return images, has_nsfw_concepts def _generate( self, prompt_ids: jnp.ndarray, mask: jnp.ndarray, masked_image: jnp.ndarray, params: Union[Dict, FrozenDict], prng_seed: jax.Array, num_inference_steps: int, height: int, width: int, guidance_scale: float, latents: Optional[jnp.ndarray] = None, neg_prompt_ids: Optional[jnp.ndarray] = None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") # get prompt text embeddings prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` batch_size = prompt_ids.shape[0] max_length = prompt_ids.shape[-1] if neg_prompt_ids is None: uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" ).input_ids else: uncond_input = neg_prompt_ids negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) latents_shape = ( batch_size, self.vae.config.latent_channels, height // self.vae_scale_factor, width // self.vae_scale_factor, ) if latents is None: latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") prng_seed, mask_prng_seed = jax.random.split(prng_seed) masked_image_latent_dist = self.vae.apply( {"params": params["vae"]}, masked_image, method=self.vae.encode ).latent_dist masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2)) masked_image_latents = self.vae.config.scaling_factor * masked_image_latents del 
mask_prng_seed mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest") # 8. Check that sizes of mask, masked image and latents match num_channels_latents = self.vae.config.latent_channels num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) def loop_body(step, args): latents, mask, masked_image_latents, scheduler_state = args # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes latents_input = jnp.concatenate([latents] * 2) mask_input = jnp.concatenate([mask] * 2) masked_image_latents_input = jnp.concatenate([masked_image_latents] * 2) t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] timestep = jnp.broadcast_to(t, latents_input.shape[0]) latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) # concat latents, mask, masked_image_latents in the channel dimension latents_input = jnp.concatenate([latents_input, mask_input, masked_image_latents_input], axis=1) # predict the noise residual noise_pred = self.unet.apply( {"params": params["unet"]}, jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), encoder_hidden_states=context, ).sample # perform guidance noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents, mask, masked_image_latents, scheduler_state scheduler_state = self.scheduler.set_timesteps( params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape ) # scale the initial noise by the standard deviation required by the scheduler latents = latents * params["scheduler"].init_noise_sigma if DEBUG: # run with python for loop for i in range(num_inference_steps): latents, mask, masked_image_latents, scheduler_state = loop_body( i, (latents, mask, masked_image_latents, scheduler_state) ) else: latents, _, _, _ = jax.lax.fori_loop( 0, num_inference_steps, loop_body, (latents, mask, masked_image_latents, scheduler_state) ) # scale and decode the image latents with vae latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.apply({"params": params["vae"]}, latents, 
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt_ids: jnp.ndarray,
    mask: jnp.ndarray,
    masked_image: jnp.ndarray,
    params: Union[Dict, FrozenDict],
    prng_seed: jax.Array,
    num_inference_steps: int = 50,
    height: Optional[int] = None,
    width: Optional[int] = None,
    guidance_scale: Union[float, jnp.ndarray] = 7.5,
    latents: jnp.ndarray = None,
    neg_prompt_ids: jnp.ndarray = None,
    return_dict: bool = True,
    jit: bool = False,
):
    r"""
    Run inpainting generation and, when a safety checker is configured, filter the results.

    Args:
        prompt_ids (`jnp.ndarray`):
            Token ids of the prompt(s). The leading dimension (and the rank, to detect sharding) is also used
            to shape a scalar `guidance_scale`.
        mask (`jnp.ndarray`):
            Inpainting mask. Resized to `(height, width)` over its last two axes with nearest-neighbour
            interpolation.
        masked_image (`jnp.ndarray`):
            Masked source image. Resized to `(height, width)` over its last two axes with bicubic interpolation.
        params (`Dict` or `FrozenDict`):
            Parameters forwarded to the generation function. `params["safety_checker"]` is read when a safety
            checker is set.
        prng_seed (`jax.Array`):
            Random key(s) forwarded to the generation function.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        guidance_scale (`float` or `jnp.ndarray`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            prompt at the expense of lower image quality. A Python `float` is expanded to one value per entry
            of `prompt_ids`; an array is forwarded unchanged.
        latents (`jnp.ndarray`, *optional*):
            Pre-generated noisy latents, forwarded to the generation function.
        neg_prompt_ids (`jnp.ndarray`, *optional*):
            Token ids of the negative prompt(s), forwarded to the generation function.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead
            of a plain tuple.
        jit (`bool`, defaults to `False`):
            Whether to run the `pmap` version of the generation function (and pass the flag on to the safety
            checker) instead of the plain Python method.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is
            returned, otherwise a `tuple` `(images, has_nsfw_concept)`. `has_nsfw_concept` is `False` when no
            safety checker is configured.
    """
    # 0. Fall back to the UNet's native resolution for any side that was not given
    if not height:
        height = self.unet.config.sample_size * self.vae_scale_factor
    if not width:
        width = self.unet.config.sample_size * self.vae_scale_factor

    # Bring both conditioning inputs to the target resolution (only the last two axes change)
    masked_image = jax.image.resize(masked_image, (*masked_image.shape[:-2], height, width), method="bicubic")
    mask = jax.image.resize(mask, (*mask.shape[:-2], height, width), method="nearest")

    if isinstance(guidance_scale, float):
        # One copy of the scale per leading entry of `prompt_ids`, so that each device receives a value.
        guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
        if prompt_ids.ndim > 2:
            # rank > 2 is taken to mean the inputs are sharded: add a per-device batch axis
            guidance_scale = jnp.expand_dims(guidance_scale, axis=1)

    generate_args = (
        prompt_ids,
        mask,
        masked_image,
        params,
        prng_seed,
        num_inference_steps,
        height,
        width,
        guidance_scale,
        latents,
        neg_prompt_ids,
    )
    images = _p_generate(self, *generate_args) if jit else self._generate(*generate_args)

    if self.safety_checker is None:
        images = np.asarray(images)
        has_nsfw_concept = False
    else:
        num_devices, batch_size = images.shape[:2]
        flat_count = num_devices * batch_size

        # The checker receives a flat uint8 batch
        checker_input = np.asarray((images * 255).round().astype("uint8"))
        checker_input = checker_input.reshape(flat_count, height, width, 3)
        checked_images, has_nsfw_concept = self._run_safety_checker(checker_input, params["safety_checker"], jit)
        images = np.asarray(images)

        # Replace every flagged entry with the checker's output for it
        if any(has_nsfw_concept):
            for idx, flagged in enumerate(has_nsfw_concept):
                if flagged:
                    images[idx] = np.asarray(checked_images[idx])

        images = images.reshape(num_devices, batch_size, height, width, 3)

    if return_dict:
        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
    return (images, has_nsfw_concept)
# Arguments 0, 6, 7 and 8 (`pipe`, `num_inference_steps`, `height`, `width`) are static and broadcast to
# every device. All remaining arguments are mapped over their leading axis (`in_axes` entry `0`).
@partial(
    jax.pmap,
    in_axes=(None, 0, 0, 0, 0, 0, None, None, None, 0, 0, 0),
    static_broadcasted_argnums=(0, 6, 7, 8),
)
def _p_generate(
    pipe,
    prompt_ids,
    mask,
    masked_image,
    params,
    prng_seed,
    num_inference_steps,
    height,
    width,
    guidance_scale,
    latents,
    neg_prompt_ids,
):
    """`pmap`-wrapped call to `pipe._generate`, forwarding every argument positionally."""
    return pipe._generate(
        prompt_ids,
        mask,
        masked_image,
        params,
        prng_seed,
        num_inference_steps,
        height,
        width,
        guidance_scale,
        latents,
        neg_prompt_ids,
    )


@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_has_nsfw_concepts(pipe, features, params):
    """`pmap`-wrapped call to `pipe._get_has_nsfw_concepts`; `pipe` is the only static argument."""
    return pipe._get_has_nsfw_concepts(features, params)


def unshard(x: jnp.ndarray):
    """Collapse the two leading axes of `x` into one, i.e. `(d, b, ...) -> (d * b, ...)`."""
    first, second, *trailing = x.shape
    return x.reshape(first * second, *trailing)


def preprocess_image(image, dtype):
    """
    Turn an image object into a `(1, C, H, W)` array scaled to `[-1, 1]`.

    `image` must expose `.size` and `.resize(...)` (PIL-style). Both sides are rounded down to a multiple
    of 32 before the Lanczos resize.
    """
    width, height = image.size
    width -= width % 32  # round down to an integer multiple of 32
    height -= height % 32
    resized = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
    pixels = jnp.array(resized).astype(dtype) / 255.0
    # add a batch axis, then move channels in front of the spatial axes
    batched = jnp.transpose(pixels[None], (0, 3, 1, 2))
    return 2.0 * batched - 1.0


def preprocess_mask(mask, dtype):
    """
    Turn a mask object into a `(1, 1, H, W)` array scaled to `[0, 1]`.

    `mask` must expose `.size`, `.resize(...)` and `.convert("L")` (PIL-style). Both sides are rounded down
    to a multiple of 32 before resizing.
    """
    width, height = mask.size
    width -= width % 32  # round down to an integer multiple of 32
    height -= height % 32
    resized = mask.resize((width, height))
    grey = jnp.array(resized.convert("L")).astype(dtype) / 255.0
    return jnp.expand_dims(grey, axis=(0, 1))
class OnnxStableDiffusionPipeline(DiffusionPipeline):
    r"""
    Text-to-image Stable Diffusion pipeline whose sub-models are `OnnxRuntimeModel`s.

    The text encoder, UNet, VAE decoder and (optional) safety checker are called with numpy inputs. The
    scheduler is fed torch tensors, so latents are converted around every scheduler call.
    """

    vae_encoder: OnnxRuntimeModel
    vae_decoder: OnnxRuntimeModel
    text_encoder: OnnxRuntimeModel
    tokenizer: CLIPTokenizer
    unet: OnnxRuntimeModel
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
    safety_checker: OnnxRuntimeModel
    feature_extractor: CLIPImageProcessor

    _optional_components = ["safety_checker", "feature_extractor"]
    _is_onnx = True

    def __init__(
        self,
        vae_encoder: OnnxRuntimeModel,
        vae_decoder: OnnxRuntimeModel,
        text_encoder: OnnxRuntimeModel,
        tokenizer: CLIPTokenizer,
        unet: OnnxRuntimeModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: OnnxRuntimeModel,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        r"""
        Register the pipeline components, patching outdated scheduler configs on the way.

        A scheduler config with `steps_offset != 1` or `clip_sample is True` triggers a deprecation notice and
        is overwritten in place (`steps_offset=1`, `clip_sample=False`).

        Raises:
            ValueError: If a `safety_checker` is given without a `feature_extractor`.
        """
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # The message needs the `f` prefix; without it `{self.__class__}` is printed literally.
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            vae_encoder=vae_encoder,
            vae_decoder=vae_decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def _encode_prompt(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: Optional[int],
        do_classifier_free_guidance: bool,
        negative_prompt: Optional[str],
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded. May be `None` when `prompt_embeds` is given.
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is less than `1`).
            prompt_embeds (`np.ndarray`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`np.ndarray`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.

        Returns:
            `np.ndarray`: The prompt embeddings repeated `num_images_per_prompt` times. With classifier free
            guidance, the negative embeddings are concatenated in front of them along the batch axis.

        Raises:
            TypeError: If both `prompt` and `negative_prompt` are given but have different types.
            ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="np",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

            if not np.array_equal(text_input_ids, untruncated_ids):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]

        prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            # Only compare types when a text prompt exists. With `prompt_embeds` the prompt is `None` and the
            # comparison would reject every `negative_prompt`.
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt] * batch_size
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="np",
            )
            negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]

        if do_classifier_free_guidance:
            negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def check_inputs(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int],
        width: Optional[int],
        callback_steps: int,
        negative_prompt: Optional[str] = None,
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
    ):
        r"""
        Validate the arguments of `__call__`.

        Raises:
            ValueError: If `height` or `width` is not divisible by 8, if `callback_steps` is not a positive
                integer, if `prompt` and `prompt_embeds` are both given or both missing, if `prompt` is
                neither `str` nor `list`, if `negative_prompt` and `negative_prompt_embeds` are both given,
                or if `prompt_embeds` and `negative_prompt_embeds` differ in shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[np.random.RandomState] = None,
        latents: Optional[np.ndarray] = None,
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass
                `prompt_embeds` instead.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image. Has to be divisible by 8.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image. Has to be divisible by 8.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only passed
                to the scheduler when its `step` method accepts an `eta` argument.
            generator (`np.random.RandomState`, *optional*):
                Random state used to draw the initial latents. Defaults to the global `np.random` module.
            latents (`np.ndarray`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents array will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`np.ndarray`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`np.ndarray`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker` (`None` when no safety checker
            is set).

        Raises:
            ValueError: If the inputs fail `check_inputs`, or if `latents` is given with an unexpected shape.
        """

        # check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if generator is None:
            generator = np.random

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_embeds = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # get the initial random noise unless the user supplied it
        latents_dtype = prompt_embeds.dtype
        latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
        if latents is None:
            latents = generator.randn(*latents_shape).astype(latents_dtype)
        elif latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        latents = latents * np.float64(self.scheduler.init_noise_sigma)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # Use the timestep dtype declared by the ONNX UNet graph, defaulting to float
        timestep_dtype = next(
            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
        )
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
            latent_model_input = latent_model_input.cpu().numpy()

            # predict the noise residual
            timestep = np.array([t], dtype=timestep_dtype)
            noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
            noise_pred = noise_pred[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            scheduler_output = self.scheduler.step(
                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
            )
            latents = scheduler_output.prev_sample.numpy()

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

        latents = 1 / 0.18215 * latents
        # image = self.vae_decoder(latent_sample=latents)[0]
        # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
        image = np.concatenate(
            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
        )

        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(image), return_tensors="np"
            ).pixel_values.astype(image.dtype)

            # The safety checker is run one image at a time
            images, has_nsfw_concept = [], []
            for i in range(image.shape[0]):
                image_i, has_nsfw_concept_i = self.safety_checker(
                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
                )
                images.append(image_i)
                has_nsfw_concept.append(has_nsfw_concept_i[0])
            image = np.concatenate(images)
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64
def preprocess(image):
    """
    Deprecated helper that converts PIL image(s) or tensor(s) into a single torch tensor.

    A tensor is returned unchanged. PIL images are resized (Lanczos) so both sides are a multiple of 64,
    stacked into an `(N, C, H, W)` float32 batch and scaled to `[-1, 1]`; the target size comes from the
    first image. A list of tensors is concatenated along dimension 0. Anything else is returned as received.
    """
    deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
    deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)

    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        # a single image is handled as a batch of one
        image = [image]

    first = image[0]
    if isinstance(first, PIL.Image.Image):
        width, height = first.size
        width -= width % 64  # resize to integer multiple of 64
        height -= height % 64
        target_size = (width, height)
        frames = [
            np.array(frame.resize(target_size, resample=PIL_INTERPOLATION["lanczos"]))[None, :] for frame in image
        ]
        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        batch = 2.0 * batch - 1.0
        return torch.from_numpy(batch)
    if isinstance(first, torch.Tensor):
        return torch.cat(image, dim=0)
    return image
def __init__(
    self,
    vae_encoder: OnnxRuntimeModel,
    vae_decoder: OnnxRuntimeModel,
    text_encoder: OnnxRuntimeModel,
    tokenizer: CLIPTokenizer,
    unet: OnnxRuntimeModel,
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    safety_checker: OnnxRuntimeModel,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    r"""
    Register the pipeline components, patching outdated scheduler configs on the way.

    A scheduler config with `steps_offset != 1` or `clip_sample is True` triggers a deprecation notice and is
    overwritten in place (`steps_offset=1`, `clip_sample=False`). A warning is logged when the safety checker
    is disabled while `requires_safety_checker` is set.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message needs the `f` prefix; without it `{self.__class__}` is printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae_encoder=vae_encoder,
        vae_decoder=vae_decoder,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids if not np.array_equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] * batch_size elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
def check_inputs(
    self,
    prompt: Union[str, List[str]],
    callback_steps: int,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
):
    r"""
    Validate the user-facing arguments of the pipeline call.

    Args:
        prompt (`str` or `List[str]`):
            Text prompt(s); mutually exclusive with `prompt_embeds`.
        callback_steps (`int`):
            Must be a strictly positive integer.
        negative_prompt (`str` or `List[str]`, *optional*):
            Negative text prompt(s); mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings; must have the same shape as `prompt_embeds` when both are
            given.

    Raises:
        ValueError: On the first violated constraint.
    """
    steps_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None

    # Exactly one of `prompt` / `prompt_embeds` has to be supplied.
    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Embeddings passed directly must be shape-compatible with each other.
    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: Union[np.ndarray, PIL.Image.Image] = None,
    strength: float = 0.8,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[np.random.RandomState] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
    callback_steps: int = 1,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation. Pass `None` when supplying `prompt_embeds`.
        image (`np.ndarray` or `PIL.Image.Image`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process.
        strength (`float`, *optional*, defaults to 0.8):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference. This parameter will be modulated by `strength`.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
            if `guidance_scale` is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`np.random.RandomState`, *optional*):
            A np.random.RandomState to make generation deterministic.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`.

    Raises:
        ValueError: If the inputs fail `check_inputs`, if `strength` is outside `[0, 1]`, or if the image batch
            cannot be evenly duplicated to match the prompt batch.
    """
    # check inputs. Raise error if not correct
    self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    if generator is None:
        generator = np.random

    # set timesteps
    self.scheduler.set_timesteps(num_inference_steps)

    image = preprocess(image).cpu().numpy()

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    prompt_embeds = self._encode_prompt(
        prompt,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    latents_dtype = prompt_embeds.dtype
    image = image.astype(latents_dtype)
    # encode the init image into latents and scale the latents
    init_latents = self.vae_encoder(sample=image)[0]
    init_latents = 0.18215 * init_latents

    # Compare the image batch against `batch_size` rather than `len(prompt)`: `prompt` is `None` when the
    # caller supplies `prompt_embeds`, and `batch_size` equals the prompt count whenever a prompt is given.
    num_init_images = init_latents.shape[0]
    if batch_size > num_init_images and batch_size % num_init_images == 0:
        # expand init_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {num_init_images} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        additional_image_per_prompt = batch_size // num_init_images
        init_latents = np.concatenate([init_latents] * additional_image_per_prompt * num_images_per_prompt, axis=0)
    elif batch_size > num_init_images and batch_size % num_init_images != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {num_init_images} to {batch_size} text prompts."
        )
    else:
        init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0)

    # get the original timestep using init_timestep
    offset = self.scheduler.config.get("steps_offset", 0)
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)

    timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
    timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)

    # add noise to latents using the timesteps
    noise = generator.randn(*init_latents.shape).astype(latents_dtype)
    init_latents = self.scheduler.add_noise(
        torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
    )
    init_latents = init_latents.numpy()

    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]
    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    latents = init_latents

    # Skip the leading part of the schedule; only the tail selected by `strength` is denoised.
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    timesteps = self.scheduler.timesteps[t_start:].numpy()

    # Use the dtype the ONNX UNet declares for its `timestep` input, defaulting to float.
    timestep_dtype = next(
        (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
    )
    timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

    for i, t in enumerate(self.progress_bar(timesteps)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
        latent_model_input = latent_model_input.cpu().numpy()

        # predict the noise residual
        timestep = np.array([t], dtype=timestep_dtype)
        noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
            0
        ]

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        scheduler_output = self.scheduler.step(
            torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
        )
        latents = scheduler_output.prev_sample.numpy()

        # call the callback, if provided
        if callback is not None and i % callback_steps == 0:
            step_idx = i // getattr(self.scheduler, "order", 1)
            callback(step_idx, t, latents)

    # Undo the 0.18215 scaling applied after encoding.
    latents = 1 / 0.18215 * latents
    # image = self.vae_decoder(latent_sample=latents)[0]
    # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
    image = np.concatenate(
        [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
    )

    # Map the decoder output from [-1, 1] to [0, 1] and move channels last.
    image = np.clip(image / 2 + 0.5, 0, 1)
    image = image.transpose((0, 2, 3, 1))

    if self.safety_checker is not None:
        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="np"
        ).pixel_values.astype(image.dtype)
        # safety_checker does not support batched inputs yet
        images, has_nsfw_concept = [], []
        for i in range(image.shape[0]):
            image_i, has_nsfw_concept_i = self.safety_checker(
                clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
            )
            images.append(image_i)
            has_nsfw_concept.append(has_nsfw_concept_i[0])
        image = np.concatenate(images)
    else:
        has_nsfw_concept = None

    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def prepare_mask_and_masked_image(image, mask, latents_shape):
    """Build the latent-resolution binary mask and the masked pixel-space image.

    Args:
        image: PIL image to inpaint; converted to RGB and resized to 8x the latent size.
        mask: PIL mask image; converted to single-channel ("L") before use.
        latents_shape: Indexable whose first two entries are the latent height and width.

    Returns:
        A `(mask, masked_image)` tuple. `mask` is a float32 array with two leading singleton axes, at latent
        resolution, holding only 0 and 1 (1 where the resized mask is >= 0.5 after scaling to [0, 1]).
        `masked_image` is the float32 image in NCHW layout, scaled via `/ 127.5 - 1.0`, multiplied by 0
        wherever the full-resolution mask is >= 127.5.
    """
    latent_height, latent_width = latents_shape[0], latents_shape[1]
    # PIL sizes are (width, height); pixel space is 8x the latent resolution.
    pixel_size = (latent_width * 8, latent_height * 8)

    pixels = np.array(image.convert("RGB").resize(pixel_size))
    pixels = pixels[None].transpose(0, 3, 1, 2)
    pixels = pixels.astype(np.float32) / 127.5 - 1.0

    # Keep only the pixels the mask marks as dark (i.e. to be preserved).
    keep_region = np.array(mask.convert("L").resize(pixel_size)) < 127.5
    masked_image = pixels * keep_region

    latent_mask = mask.resize((latent_width, latent_height), PIL_INTERPOLATION["nearest"])
    latent_mask = np.array(latent_mask.convert("L")).astype(np.float32) / 255.0
    latent_mask = latent_mask[None, None]
    # Binarise at 0.5.
    latent_mask = (latent_mask >= 0.5).astype(np.float32)

    return latent_mask, masked_image
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" vae_encoder: OnnxRuntimeModel vae_decoder: OnnxRuntimeModel text_encoder: OnnxRuntimeModel tokenizer: CLIPTokenizer unet: OnnxRuntimeModel scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] safety_checker: OnnxRuntimeModel feature_extractor: CLIPImageProcessor _optional_components = ["safety_checker", "feature_extractor"] _is_onnx = True def __init__( self, vae_encoder: OnnxRuntimeModel, vae_decoder: OnnxRuntimeModel, text_encoder: OnnxRuntimeModel, tokenizer: CLIPTokenizer, unet: OnnxRuntimeModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. 
def __init__(
    self,
    vae_encoder: OnnxRuntimeModel,
    vae_decoder: OnnxRuntimeModel,
    text_encoder: OnnxRuntimeModel,
    tokenizer: CLIPTokenizer,
    unet: OnnxRuntimeModel,
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    safety_checker: OnnxRuntimeModel,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    r"""
    Register the inpainting pipeline components and patch outdated scheduler configurations.

    Args:
        vae_encoder (`OnnxRuntimeModel`):
            Registered as the `vae_encoder` module.
        vae_decoder (`OnnxRuntimeModel`):
            Registered as the `vae_decoder` module.
        text_encoder (`OnnxRuntimeModel`):
            Registered as the `text_encoder` module.
        tokenizer (`CLIPTokenizer`):
            Registered as the `tokenizer` module.
        unet (`OnnxRuntimeModel`):
            Registered as the `unet` module.
        scheduler (`DDIMScheduler`, `PNDMScheduler` or `LMSDiscreteScheduler`):
            Registered as the `scheduler` module. Its config is rewritten in place when `steps_offset != 1` or
            `clip_sample is True` (a deprecation notice is emitted for each).
        safety_checker (`OnnxRuntimeModel`):
            Registered as the `safety_checker` module. May be `None`.
        feature_extractor (`CLIPImageProcessor`):
            Registered as the `feature_extractor` module. Required whenever `safety_checker` is not `None`.
        requires_safety_checker (`bool`, *optional*, defaults to `True`):
            Stored in the pipeline config. When `True` and `safety_checker` is `None`, a warning is logged.

    Raises:
        ValueError: If `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()
    logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")

    # Legacy configs: force `steps_offset` to 1 by swapping in a patched frozen config.
    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    # Legacy configs: force `clip_sample` to False the same way.
    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string, otherwise `{self.__class__}` is emitted verbatim.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae_encoder=vae_encoder,
        vae_decoder=vae_decoder,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
) self.register_modules( vae_encoder=vae_encoder, vae_decoder=vae_decoder, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt: Union[str, List[str]], num_images_per_prompt: Optional[int], do_classifier_free_guidance: bool, negative_prompt: Optional[str], prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids if not np.array_equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] * batch_size elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs
def check_inputs(
    self,
    prompt: Union[str, List[str]],
    height: Optional[int],
    width: Optional[int],
    callback_steps: int,
    negative_prompt: Optional[str] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
):
    r"""
    Validate the user-facing arguments of the pipeline call.

    Args:
        prompt (`str` or `List[str]`):
            Text prompt(s); mutually exclusive with `prompt_embeds`.
        height (`int`):
            Output height in pixels; must be divisible by 8.
        width (`int`):
            Output width in pixels; must be divisible by 8.
        callback_steps (`int`):
            Must be a strictly positive integer.
        negative_prompt (`str`, *optional*):
            Negative text prompt; mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings; must have the same shape as `prompt_embeds` when both are
            given.

    Raises:
        ValueError: On the first violated constraint.
    """
    if any(side % 8 != 0 for side in (height, width)):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    steps_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None

    # Exactly one of `prompt` / `prompt_embeds` has to be supplied.
    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Embeddings passed directly must be shape-compatible with each other.
    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: PIL.Image.Image,
    mask_image: PIL.Image.Image,
    height: Optional[int] = 512,
    width: Optional[int] = 512,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[np.random.RandomState] = None,
    latents: Optional[np.ndarray] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
    callback_steps: int = 1,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
            be masked out with `mask_image` and repainted according to `prompt`.
        mask_image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
            repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
            to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
            instead of 3, so the expected shape would be `(B, H, W, 1)`.
        height (`int`, *optional*, defaults to 512):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to 512):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
            if `guidance_scale` is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. It is only
            forwarded to `scheduler.step` when that method accepts an `eta` argument.
        generator (`np.random.RandomState`, *optional*):
            A np.random.RandomState to make generation deterministic. The global `np.random` module is used when
            omitted.
        latents (`np.ndarray`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
        element is a list of `bool`s denoting whether the corresponding generated image likely represents
        "not-safe-for-work" (nsfw) content, according to the `safety_checker`.

    Raises:
        ValueError: If supplied `latents` do not have the expected shape, or if the latent, mask and masked-image
            channel counts do not add up to the UNet input channel count.
    """
    # check inputs. Raise error if not correct
    self.check_inputs(
        prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if generator is None:
        generator = np.random

    # set timesteps (done once; nothing below touches the scheduler before the denoising loop)
    self.scheduler.set_timesteps(num_inference_steps)

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    prompt_embeds = self._encode_prompt(
        prompt,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    num_channels_latents = NUM_LATENT_CHANNELS
    latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
    latents_dtype = prompt_embeds.dtype
    if latents is None:
        latents = generator.randn(*latents_shape).astype(latents_dtype)
    elif latents.shape != latents_shape:
        raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

    # prepare mask and masked_image
    mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:])
    mask = mask.astype(latents.dtype)
    masked_image = masked_image.astype(latents.dtype)

    masked_image_latents = self.vae_encoder(sample=masked_image)[0]
    masked_image_latents = 0.18215 * masked_image_latents

    # duplicate mask and masked_image_latents for each generation per prompt
    mask = mask.repeat(batch_size * num_images_per_prompt, 0)
    masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0)

    mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask
    masked_image_latents = (
        np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
    )

    num_channels_mask = mask.shape[1]
    num_channels_masked_image = masked_image_latents.shape[1]

    unet_input_channels = NUM_UNET_INPUT_CHANNELS
    if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels:
        raise ValueError(
            "Incorrect configuration settings! The config of `pipeline.unet` expects"
            f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +"
            f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
            f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
            " `pipeline.unet` or your `mask_image` or `image` input."
        )

    # scale the initial noise by the standard deviation required by the scheduler.
    # Under NumPy >= 2 promotion rules (NEP 50) multiplying by an `np.float64` scalar would upcast
    # float16/float32 latents to float64, so the result is cast back to the incoming dtype.
    initial_latents_dtype = latents.dtype
    latents = (latents * np.float64(self.scheduler.init_noise_sigma)).astype(initial_latents_dtype)

    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]
    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    # look up the dtype the ONNX UNet declares for its `timestep` input (falls back to float)
    timestep_dtype = next(
        (
            model_input.type
            for model_input in self.unet.model.get_inputs()
            if model_input.name == "timestep"
        ),
        "tensor(float)",
    )
    timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

    for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
        # the scheduler operates on torch tensors, so round-trip through torch for the scaling
        latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
        latent_model_input = latent_model_input.cpu().numpy()
        # concat latents, mask, masked_image_latents in the channel dimension
        latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)

        # predict the noise residual
        timestep = np.array([t], dtype=timestep_dtype)
        noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
            0
        ]

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        scheduler_output = self.scheduler.step(
            torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
        )
        latents = scheduler_output.prev_sample.numpy()

        # call the callback, if provided
        if callback is not None and i % callback_steps == 0:
            step_idx = i // getattr(self.scheduler, "order", 1)
            callback(step_idx, t, latents)

    latents = 1 / 0.18215 * latents
    # image = self.vae_decoder(latent_sample=latents)[0]
    # decode one sample at a time: the half-precision vae decoder seems to give strange results if batchsize>1
    image = np.concatenate(
        [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
    )

    image = np.clip(image / 2 + 0.5, 0, 1)
    image = image.transpose((0, 2, 3, 1))

    if self.safety_checker is not None:
        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="np"
        ).pixel_values.astype(image.dtype)
        # safety_checker does not support batched inputs yet
        images, has_nsfw_concept = [], []
        for i in range(image.shape[0]):
            image_i, has_nsfw_concept_i = self.safety_checker(
                clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
            )
            images.append(image_i)
            has_nsfw_concept.append(has_nsfw_concept_i[0])
        image = np.concatenate(images)
    else:
        has_nsfw_concept = None

    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def preprocess(image):
    """
    Convert the supported image inputs into a single `torch.Tensor` batch.

    Args:
        image (`torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list`):
            - a `torch.Tensor` is returned unchanged;
            - a `np.ndarray` is wrapped with `torch.from_numpy` without any rescaling or transposition;
            - a PIL image (or list of PIL images) is resized so that width and height are rounded down to a
              multiple of 64 (using the size of the first image), scaled from `[0, 255]` to `[-1, 1]` and moved
              to channels-first layout;
            - a list of tensors is concatenated along dimension 0.

    Returns:
        `torch.Tensor` for every case above. A list whose first element is of any other type is returned
        unchanged.
    """
    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, np.ndarray):
        # The pipeline's `check_inputs` accepts arrays and its `__call__` calls `.cpu()` on the result,
        # so arrays must leave here as tensors.
        # NOTE(review): assumes the array is already laid out and scaled like the tensor input — confirm.
        return torch.from_numpy(image)
    if isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = (x - x % 64 for x in (w, h))  # round down to an integer multiple of 64

        image = [np.array(i.resize((w, h)))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)

    return image
scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. 
Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, low_res_scheduler=low_res_scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.register_to_config( max_noise_level=max_noise_level, num_latent_channels=num_latent_channels, num_unet_input_channels=num_unet_input_channels, ) def check_inputs( self, prompt: Union[str, List[str]], image, noise_level, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if ( not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) and not isinstance(image, np.ndarray) and not isinstance(image, list) ): raise ValueError( f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}" ) # verify batch size of prompt and image are same if image is a list or tensor or numpy array if isinstance(image, (list, np.ndarray)): if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if isinstance(image, list): image_batch_size = len(image) else: image_batch_size = image.shape[0] if batch_size != image_batch_size: raise ValueError( f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." " Please make sure that passed `prompt` matches the batch size of `image`." 
) # check noise level if noise_level > self.config.max_noise_level: raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): shape = (batch_size, num_channels_latents, height, width) if latents is None: latents = generator.randn(*shape).astype(dtype) elif latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") return latents def decode_latents(self, latents): latents = 1 / 0.08333 * latents image = self.vae(latent_sample=latents)[0] image = np.clip(image / 2 + 0.5, 0, 1) image = image.transpose((0, 2, 3, 1)) return image def _encode_prompt( self, prompt: Union[str, List[str]], num_images_per_prompt: Optional[int], do_classifier_free_guidance: bool, negative_prompt: Optional[str], prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids if not np.array_equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] * batch_size elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. 
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: Union[np.ndarray, PIL.Image.Image, List[PIL.Image.Image]],
    num_inference_steps: int = 75,
    guidance_scale: float = 9.0,
    noise_level: int = 20,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[np.random.RandomState, List[np.random.RandomState]]] = None,
    latents: Optional[np.ndarray] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
    callback_steps: Optional[int] = 1,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        image (`np.ndarray` or `PIL.Image.Image`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process.
        num_inference_steps (`int`, *optional*, defaults to 75):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 9.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        noise_level (`int`, defaults to 20):
            Determines the amount of noise to add to the initial image before performing upscaling. Must not
            exceed `self.config.max_noise_level`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
            if `guidance_scale` is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. It is only
            forwarded to `scheduler.step` when that method accepts an `eta` argument.
        generator (`np.random.RandomState`, *optional*):
            A np.random.RandomState to make generation deterministic. The global `np.random` module is used when
            omitted.
        latents (`np.ndarray`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
        element is a list of `bool`s denoting whether the corresponding generated image likely represents
        "not-safe-for-work" (nsfw) content, according to the `safety_checker`.

    Raises:
        ValueError: If supplied `latents` do not have the expected shape, or if the latent and image channel
            counts do not add up to `self.config.num_unet_input_channels`.
    """
    # 1. Check inputs
    self.check_inputs(
        prompt,
        image,
        noise_level,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if generator is None:
        generator = np.random

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode the prompt
    prompt_embeds = self._encode_prompt(
        prompt,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    latents_dtype = prompt_embeds.dtype

    # 4. Preprocess the image and prepare the latents.
    # `preprocess` may hand back something that is not a tensor (e.g. an array), so only call
    # `.cpu()` when it actually returned a tensor.
    image = preprocess(image)
    image = image.cpu().numpy() if isinstance(image, torch.Tensor) else np.asarray(image)
    height, width = image.shape[2:]

    # forward user-supplied `latents` so they are shape-checked and used instead of being ignored
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        self.config.num_latent_channels,
        height,
        width,
        latents_dtype,
        generator,
        latents,
    )
    image = image.astype(latents_dtype)

    self.scheduler.set_timesteps(num_inference_steps)
    timesteps = self.scheduler.timesteps

    # Scale the initial noise by the standard deviation required by the scheduler.
    # Under NumPy >= 2 promotion rules (NEP 50) multiplying by an `np.float64` scalar would upcast
    # float16/float32 latents to float64, so the result is cast back to the incoming dtype.
    initial_latents_dtype = latents.dtype
    latents = (latents * np.float64(self.scheduler.init_noise_sigma)).astype(initial_latents_dtype)

    # 5. Add noise to image
    noise_level = np.array([noise_level]).astype(np.int64)
    noise = generator.randn(*image.shape).astype(latents_dtype)

    image = self.low_res_scheduler.add_noise(
        torch.from_numpy(image), torch.from_numpy(noise), torch.from_numpy(noise_level)
    )
    image = image.numpy()

    # 6. Tile the noisy image (and its noise level) to match the UNet batch
    batch_multiplier = 2 if do_classifier_free_guidance else 1
    image = np.concatenate([image] * batch_multiplier * num_images_per_prompt)
    noise_level = np.concatenate([noise_level] * image.shape[0])

    # 7. Check that sizes of image and latents match
    num_channels_image = image.shape[1]
    if self.config.num_latent_channels + num_channels_image != self.config.num_unet_input_channels:
        raise ValueError(
            "Incorrect configuration settings! The config of `pipeline.unet` expects"
            f" {self.config.num_unet_input_channels} but received `num_channels_latents`: {self.config.num_latent_channels} +"
            f" `num_channels_image`: {num_channels_image} "
            f" = {self.config.num_latent_channels + num_channels_image}. Please verify the config of"
            " `pipeline.unet` or your `image` input."
        )

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    # look up the dtype the ONNX UNet declares for its `timestep` input (falls back to float)
    timestep_dtype = next(
        (
            model_input.type
            for model_input in self.unet.model.get_inputs()
            if model_input.name == "timestep"
        ),
        "tensor(float)",
    )
    timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

    # 9. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents

            # the scheduler operates on torch tensors, so round-trip through torch for the scaling
            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
            latent_model_input = latent_model_input.cpu().numpy()
            # concat latents and the noisy low-resolution image in the channel dimension
            latent_model_input = np.concatenate([latent_model_input, image], axis=1)

            # timestep to tensor
            timestep = np.array([t], dtype=timestep_dtype)

            # predict the noise residual
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                class_labels=noise_level,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
            ).prev_sample
            latents = latents.numpy()

            # update the progress bar and call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    # 10. Post-processing
    image = self.decode_latents(latents)

    if self.safety_checker is not None:
        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="np"
        ).pixel_values.astype(image.dtype)

        # run the safety checker one image at a time
        images, has_nsfw_concept = [], []
        for i in range(image.shape[0]):
            image_i, has_nsfw_concept_i = self.safety_checker(
                clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
            )
            images.append(image_i)
            has_nsfw_concept.append(has_nsfw_concept_i[0])
        image = np.concatenate(images)
    else:
        has_nsfw_concept = None

    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
nsfw_content_detected (`List[bool]`): List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or `None` if safety checking could not be performed. """ images: np.ndarray nsfw_content_detected: List[bool] ================================================ FILE: src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend the guided noise prediction with a copy of itself whose standard deviation is matched to the
    text-conditioned prediction.

    Based on Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg (`torch.Tensor`):
            Noise prediction obtained after applying guidance.
        noise_pred_text (`torch.Tensor`):
            Text-conditioned noise prediction; its standard deviation is the rescaling target.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Weight given to the rescaled prediction; `1 - guidance_rescale` is given to `noise_cfg`.

    Returns:
        `torch.Tensor`: The weighted mix of the rescaled and the original guided prediction.
    """

    def _std_over_non_leading_dims(tensor):
        # reduce over every axis except the first one, keeping dims so the result broadcasts back
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    std_ratio = _std_over_non_leading_dims(noise_pred_text) / _std_over_non_leading_dims(noise_cfg)
    # rescaling the guided result fixes overexposure
    rescaled = noise_cfg * std_ratio
    # mixing with the un-rescaled result avoids "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and read the resulting schedule back.

    Exactly one of `num_inference_steps`, `timesteps` or `sigmas` drives the call. Any extra `kwargs` are
    forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and to read the timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps. Used only when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Requires `set_timesteps` to accept a `timesteps` argument.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Requires `set_timesteps` to accept a `sigmas` argument.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after the call, and the number of inference steps
        (the length of that schedule when a custom schedule was supplied, otherwise `num_inference_steps`).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # default spacing: nothing custom was requested
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    accepted_arguments = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in accepted_arguments:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_arguments:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
""" model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`). A single string is applied to every item of the batch.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `Tuple[torch.Tensor, Optional[torch.Tensor]]`: `(prompt_embeds, negative_prompt_embeds)`, each repeated
        `num_images_per_prompt` times. `negative_prompt_embeds` is returned as passed in (possibly `None`) when
        `do_classifier_free_guidance` is false.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        ValueError: If a list `negative_prompt` does not match the batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # tokenize again without truncation only to report what was cut off
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            # Broadcast a single negative prompt over the whole batch. With `prompt=None` and batched
            # `prompt_embeds`, a one-element list would be reshaped below as if it held `batch_size` rows.
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # pad the unconditional input to the same sequence length as the prompt embeddings
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build the per-adapter list of IP-Adapter image embeddings.

    When `ip_adapter_image_embeds` is `None`, each entry of `ip_adapter_image` is encoded with
    `self.encode_image`; otherwise the supplied tensors are used, and under classifier-free guidance each one is
    split in two halves along the first axis (negative half first). Every embedding is then repeated
    `num_images_per_prompt` times and, under guidance, the negative copy is concatenated in front.

    Returns:
        `List[torch.Tensor]`: One tensor per IP-Adapter, moved to `device`.

    Raises:
        ValueError: If the number of images differs from the number of image projection layers of the unet.
    """
    positive_embeds = []
    negative_embeds = []

    if ip_adapter_image_embeds is not None:
        for supplied in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                # supplied tensors carry the negative embedding first, then the positive one
                negative_half, supplied = supplied.chunk(2)
                negative_embeds.append(negative_half)
            positive_embeds.append(supplied)
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        adapter_inputs = zip(ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers)
        for adapter_image, projection_layer in adapter_inputs:
            # hidden states are requested for every projection layer that is not a plain `ImageProjection`
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            encoded, encoded_negative = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positive_embeds.append(encoded[None, :])
            if do_classifier_free_guidance:
                negative_embeds.append(encoded_negative[None, :])

    prepared = []
    for index, embeds in enumerate(positive_embeds):
        embeds = torch.cat([embeds] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            repeated_negative = torch.cat([negative_embeds[index]] * num_images_per_prompt, dim=0)
            embeds = torch.cat([repeated_negative, embeds], dim=0)
        prepared.append(embeds.to(device=device))

    return prepared


def run_safety_checker(self, image, device, dtype):
    """
    Run the safety checker on `image`, if one is configured.

    Returns:
        `tuple`: `(image, has_nsfw_concept)`. Without a safety checker, `image` is returned untouched and
        `has_nsfw_concept` is `None`; otherwise both values are whatever `self.safety_checker` returns.
    """
    if self.safety_checker is None:
        return image, None

    # the feature extractor is fed PIL images, so convert from whichever format was received
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    checker_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    image, has_nsfw_concept = self.safety_checker(
        images=image, clip_input=checker_features.pixel_values.to(dtype)
    )
    return image, has_nsfw_concept
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments accepted by `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are only included when the
    signature declares them. eta (η) corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502)
    and should be between [0, 1].

    Returns:
        `dict`: Contains `"eta"` and/or `"generator"`, each only if `step` accepts it.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_parameters}


def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the arguments of a pipeline call.

    Returns nothing; every violated constraint raises a `ValueError` describing the offending argument.

    Raises:
        ValueError: If `height`/`width` are not divisible by 8, `callback_steps` is not a positive integer,
            an unknown callback tensor input is requested, `prompt`/`prompt_embeds` (or their negative or
            IP-Adapter counterparts) are supplied in an invalid combination, or an embedding has the wrong
            type, shape or rank.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is not None and not (isinstance(callback_steps, int) and callback_steps > 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None and any(
        k not in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # exactly one of `prompt` / `prompt_embeds` must be supplied
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    both_embeds_given = prompt_embeds is not None and negative_prompt_embeds is not None
    if both_embeds_given and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is None:
        return
    if not isinstance(ip_adapter_image_embeds, list):
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
        )
    # only the first entry's rank is inspected
    if ip_adapter_image_embeds[0].ndim not in [3, 4]:
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
        )


def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the initial latents, scaled by the scheduler's `init_noise_sigma`.

    Fresh noise of shape `(batch_size, num_channels_latents, height // vae_scale_factor,
    width // vae_scale_factor)` is drawn with `randn_tensor` when `latents` is `None`; otherwise the given
    `latents` are moved to `device`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(width) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        latents = latents.to(device)
    else:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return latents * self.scheduler.init_noise_sigma
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
            using zero terminal SNR.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during inference, with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        kwargs:
            Only the deprecated `callback` and `callback_steps` entries are read. `callback(step_idx, t, latents)`
            is invoked whenever the loop index is a multiple of `callback_steps`; when `callback` is given without
            `callback_steps`, it is invoked on every iteration.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """

    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )

    # `callback_steps` is optional, so a caller may pass `callback` alone. Without a fallback the
    # loop below would evaluate `i % None` and raise a TypeError in the middle of generation.
    legacy_callback_interval = 1 if callback_steps is None else callback_steps

    # Callback objects declare the tensors they need themselves; that overrides the argument.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs,
    )

    # Stored on the instance because the read-only properties (`guidance_scale`,
    # `do_classifier_free_guidance`, `interrupt`, ...) are what the rest of this call consults.
    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )

    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=lora_scale,
        clip_skip=self.clip_skip,
    )

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 6.1 Add image embeds for IP-Adapter
    added_cond_kwargs = (
        {"image_embeds": image_embeds}
        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
        else None
    )

    # 6.2 Optionally get Guidance Scale Embedding
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Once `interrupt` is set, the remaining iterations are skipped without touching `latents`.
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                # Tensors are looked up by local variable name, so the names `latents`,
                # `prompt_embeds` and `negative_prompt_embeds` in this function must not be renamed.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # update the progress bar and call the legacy callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % legacy_callback_interval == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
            0
        ]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Images flagged by the safety checker are not denormalized.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull latents out of an encoder output object.

    Args:
        encoder_output: Object exposing either a `latent_dist` attribute or a `latents` attribute.
        generator: Forwarded to `latent_dist.sample` when `sample_mode` is `"sample"`.
        sample_mode: `"sample"` draws from `latent_dist`; `"argmax"` returns `latent_dist.mode()`.

    Returns:
        The sampled latents, the distribution mode, or the stored `latents` attribute.

    Raises:
        AttributeError: If none of the access paths applies to `encoder_output`.
    """
    if hasattr(encoder_output, "latent_dist"):
        distribution = encoder_output.latent_dist
        if sample_mode == "sample":
            return distribution.sample(generator)
        if sample_mode == "argmax":
            return distribution.mode()
        # Any other mode falls through to the plain `latents` attribute below.

    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    depth_estimator: DPTForDepthEstimation,
    feature_extractor: DPTImageProcessor,
):
    """Register the pipeline components and derive the VAE scale factor.

    If the UNet config carries a `_diffusers_version` older than 0.9.0 and a `sample_size` below 64, a
    deprecation warning is issued and the UNet's config is replaced by a copy with `sample_size` set to 64.
    """
    super().__init__()

    # Both flags are always evaluated, in this order, before deciding whether to patch the config.
    saved_by_old_diffusers = False
    if hasattr(unet.config, "_diffusers_version"):
        base_version = version.parse(unet.config._diffusers_version).base_version
        saved_by_old_diffusers = version.parse(base_version) < version.parse("0.9.0.dev0")
    has_small_sample_size = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64

    if saved_by_old_diffusers and has_small_sample_size:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)

        # The config object is replaced wholesale with a patched copy rather than edited in place.
        patched_config = dict(unet.config)
        patched_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(patched_config)

    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "unet": unet,
        "scheduler": scheduler,
        "depth_estimator": depth_estimator,
        "feature_extractor": feature_extractor,
    }
    self.register_modules(**components)

    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`. Both are repeated `num_images_per_prompt` times along
        the batch dimension. When `do_classifier_free_guidance` is `False`, `negative_prompt_embeds` is returned
        exactly as it was passed in (possibly `None`).

    Raises:
        TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # The batch size comes from `prompt` when given, otherwise from the pre-computed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize a second time without truncation, only to detect and report what was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # An attention mask is passed only when the text encoder config asks for one.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Target dtype preference: text encoder, then UNet, then whatever the embeddings already use.
    # NOTE(review): `self.text_encoder` is dereferenced above when `prompt_embeds` is None, so the
    # `is not None` checks here only matter when embeddings were passed in - confirm with callers.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            # No negative prompt: use one empty string per batch entry.
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            # NOTE(review): this yields a single unconditional entry. If `prompt` is None and
            # `prompt_embeds` has a batch size above 1, the `.view(batch_size * ...)` further down
            # operates on a batch of 1 - verify whether callers can reach that combination.
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad/truncate the unconditional tokens to the sequence length of the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        # Unlike the positive branch, `clip_skip` is not applied here: the encoder's first output is used.
        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments for `self.scheduler.step`.

    Scheduler `step` signatures differ, so `eta` and `generator` are forwarded only when the
    scheduler's `step` declares a parameter of that name.

    Args:
        generator: Value passed as `generator` when the scheduler accepts it.
        eta: Value passed as `eta` when the scheduler accepts it. It corresponds to η in the DDIM
            paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        `dict`: Subset of `{"eta": eta, "generator": generator}` supported by the scheduler.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    candidates = {"eta": eta, "generator": generator}
    return {name: value for name, value in candidates.items() if name in step_parameters}
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
    """Select the tail of the scheduler's timesteps that corresponds to `strength`.

    Args:
        num_inference_steps: Number of steps the scheduler was configured with.
        strength: Fraction of the schedule to run; `int(num_inference_steps * strength)` steps are
            kept, capped at `num_inference_steps`.
        device: Unused by this method.

    Returns:
        `tuple`: The remaining slice of `self.scheduler.timesteps` and the number of steps left to run.
    """
    scheduler = self.scheduler

    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    skipped_steps = max(num_inference_steps - steps_to_run, 0)

    # Indices into `timesteps` are scaled by the scheduler order.
    remaining = scheduler.timesteps[skipped_steps * scheduler.order :]

    # Schedulers that expose `set_begin_index` are told where the schedule now starts.
    if hasattr(scheduler, "set_begin_index"):
        scheduler.set_begin_index(skipped_steps * scheduler.order)

    return remaining, num_inference_steps - skipped_steps
len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) elif isinstance(generator, list): if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " ) init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = retrieve_latents(self.vae.encode(image), generator=generator) init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." ) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device):
    """Produce a normalized, latent-resolution depth map to condition the UNet on.

    Args:
        image: A single `PIL.Image.Image`, or an iterable of PIL images, NumPy arrays or tensors. It
            determines the target resolution and, when `depth_map` is `None`, is fed to the depth estimator.
        depth_map (`torch.Tensor`, *optional*): Pre-computed depth. It is unsqueezed at dim 1 before
            interpolation, so a 3-D `(batch, height, width)` tensor is expected. When `None`, depth is
            predicted with `self.depth_estimator`.
        batch_size (`int`): Target batch size; the depth map is tiled when it has fewer entries.
        do_classifier_free_guidance (`bool`): When `True`, the result is concatenated with itself along the
            batch dimension.
        dtype: dtype of the returned tensor (also used for the estimator input and autocast).
        device: Device the computation runs on.

    Returns:
        `torch.Tensor`: Depth map of spatial size
        `(height // self.vae_scale_factor, width // self.vae_scale_factor)`, scaled per sample to `[-1, 1]`.
    """
    if isinstance(image, PIL.Image.Image):
        image = [image]
    else:
        image = list(image)

    if isinstance(image[0], PIL.Image.Image):
        # PIL reports `size` as (width, height).
        width, height = image[0].size
    elif isinstance(image[0], np.ndarray):
        # NumPy images are laid out (height, width, channels); unpacking them as (width, height)
        # transposed the target resolution for non-square inputs.
        height, width = image[0].shape[:-1]
    else:
        height, width = image[0].shape[-2:]

    if depth_map is None:
        pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device=device, dtype=dtype)
        # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16.
        # So we use `torch.autocast` here for half precision inference.
        if torch.backends.mps.is_available():
            autocast_ctx = contextlib.nullcontext()
            logger.warning(
                "The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16, but autocast is not yet supported on MPS."
            )
        else:
            autocast_ctx = torch.autocast(device.type, dtype=dtype)

        with autocast_ctx:
            depth_map = self.depth_estimator(pixel_values).predicted_depth
    else:
        depth_map = depth_map.to(device=device, dtype=dtype)

    # Resize to the latent resolution so the depth map has the spatial size of the latents.
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
        mode="bicubic",
        align_corners=False,
    )

    # Per-sample min/max normalization to [-1, 1].
    # NOTE(review): a perfectly constant depth map gives depth_max == depth_min and therefore NaNs.
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
    depth_map = depth_map.to(dtype)

    # duplicate the depth map for each generation per prompt, using mps friendly method
    if depth_map.shape[0] < batch_size:
        repeat_by = batch_size // depth_map.shape[0]
        depth_map = depth_map.repeat(repeat_by, 1, 1, 1)

    depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map
    return depth_map
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active (guidance scale strictly greater than 1)."""
    return self._guidance_scale > 1


@property
def cross_attention_kwargs(self):
    """Cross-attention kwargs stored by the most recent `__call__`."""
    return self._cross_attention_kwargs


@property
def num_timesteps(self):
    """Number of timesteps used by the most recent denoising loop."""
    return self._num_timesteps


@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: PipelineImageInput = None,
    depth_map: Optional[torch.Tensor] = None,
    strength: float = 0.8,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image` or tensor representing an image batch to be used as the starting point. Can accept image
            latents as `image` only if `depth_map` is not `None`.
        depth_map (`torch.Tensor`, *optional*):
            Depth prediction to be used as additional conditioning for the image generation process. If not
            defined, it automatically predicts the depth with `self.depth_estimator`.
        strength (`float`, *optional*, defaults to 0.8):
            Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
            starting point and more noise is added the higher the `strength`. The number of denoising steps
            depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and
            the denoising process runs for the full number of iterations specified in `num_inference_steps`. A
            value of 1 essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference. This parameter is modulated by `strength`.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. Passing
            `"latent"` skips VAE decoding and post-processes the latents directly.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return an `ImagePipelineOutput` instead of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        **kwargs:
            Only the deprecated `callback` and `callback_steps` entries are read. When `callback` is given
            without `callback_steps`, it is invoked on every step.

    Examples:

    ```py
    >>> import torch
    >>> import requests
    >>> from PIL import Image

    >>> from diffusers import StableDiffusionDepth2ImgPipeline

    >>> pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
    ...     "stabilityai/stable-diffusion-2-depth",
    ...     torch_dtype=torch.float16,
    ... )
    >>> pipe.to("cuda")

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> init_image = Image.open(requests.get(url, stream=True).raw)
    >>> prompt = "two tigers"
    >>> n_prompt = "bad, deformed, ugly, bad anotomy"
    >>> image = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0]
    ```

    Returns:
        `ImagePipelineOutput` or `tuple`:
            If `return_dict` is `True`, an `ImagePipelineOutput` is returned, otherwise a one-element `tuple`
            is returned whose only element holds the generated images.

    Raises:
        ValueError: If `image` is `None`.
    """
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )

    # Resolve the default here rather than in the signature so no list object is shared between calls.
    if callback_on_step_end_tensor_inputs is None:
        callback_on_step_end_tensor_inputs = ["latents"]

    # 1. Check inputs
    self.check_inputs(
        prompt,
        strength,
        callback_steps,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs

    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    # A legacy `callback` passed without `callback_steps` used to fail on `i % None`;
    # fall back to invoking it on every step.
    legacy_callback_every = 1 if callback_steps is None else callback_steps

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=self.clip_skip,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare depth mask
    depth_mask = self.prepare_depth_map(
        image,
        depth_map,
        batch_size * num_images_per_prompt,
        self.do_classifier_free_guidance,
        prompt_embeds.dtype,
        device,
    )

    # 5. Preprocess image
    image = self.image_processor.preprocess(image)

    # 6. Set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

    # 7. Prepare latent variables
    latents = self.prepare_latents(
        image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
    )

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 9. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            # the depth conditioning is appended along the channel axis
            latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=self.cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                # Tensors are looked up by name through `locals()`, so the local variable names in this
                # function are part of the callback contract and must not be renamed.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                depth_mask = callback_outputs.pop("depth_mask", depth_mask)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % legacy_callback_every == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    else:
        image = latents

    image = self.image_processor.postprocess(image, output_type=output_type)
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ # TODO: feature_extractor is required to encode images (if they are in PIL format), # we should give a descriptive message if the pipeline doesn't have one. 
# Components that may be absent when the pipeline is loaded.
_optional_components = ["safety_checker"]
# Order in which models are moved to the accelerator under model CPU offload.
model_cpu_offload_seq = "image_encoder->unet->vae"
_exclude_from_cpu_offload = ["safety_checker"]


def __init__(
    self,
    vae: AutoencoderKL,
    image_encoder: CLIPVisionModelWithProjection,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and derive the VAE scale factor.

    Args:
        vae: Autoencoder used to decode latents into images.
        image_encoder: CLIP vision model that embeds the conditioning image.
        unet: Denoising UNet.
        scheduler: Scheduler used together with `unet`.
        safety_checker: Optional NSFW classifier; may be `None`.
        feature_extractor: Image processor; required whenever `safety_checker` is set.
        requires_safety_checker: Stored in the pipeline config. When `True` and `safety_checker` is
            `None`, a warning is logged.

    Raises:
        ValueError: If `safety_checker` is provided without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The `f` prefix was missing, so `{self.__class__}` was shown literally instead of the class.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # UNet configs saved by diffusers < 0.9.0 with `sample_size` below 64 are patched to 64.
    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        image_encoder=image_encoder,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # One factor of 2 per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker over decoded images.

    Args:
        image: Decoded images, either a tensor or a NumPy-style batch.
        device: Device the extracted CLIP features are moved to.
        dtype: Dtype the CLIP pixel values are cast to.

    Returns:
        A `(image, has_nsfw_concept)` pair. Without a safety checker the input is returned untouched
        together with `None`; otherwise both values come from the safety checker.
    """
    checker = self.safety_checker
    if checker is None:
        return image, None

    # The feature extractor is fed PIL images, so convert according to the input kind.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    checked_image, nsfw_flags = checker(images=image, clip_input=clip_features.pixel_values.to(dtype))
    return checked_image, nsfw_flags
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments accepted by `self.scheduler.step`.

    Scheduler `step` signatures differ, so `eta` and `generator` are forwarded only when the
    signature declares a parameter of that name.

    Args:
        generator: Value forwarded as `generator` when supported.
        eta: Value forwarded as `eta` when supported (η of the DDIM paper,
            https://arxiv.org/abs/2010.02502, expected to lie in [0, 1]).

    Returns:
        dict: Subset of `{"eta": eta, "generator": generator}` that the scheduler's `step` accepts.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_parameters}
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """Create (or adopt) the initial latents and scale them for the scheduler.

    Args:
        batch_size: Effective batch size of the latent tensor.
        num_channels_latents: Number of latent channels.
        height, width: Requested image size in pixels; each is divided by `self.vae_scale_factor`.
        dtype, device: Dtype and device used when sampling fresh noise.
        generator: A generator or a list of generators; a list must have `batch_size` entries.
        latents: Optional pre-made latents, which are only moved to `device`.

    Returns:
        The latents multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        initial = latents.to(device)
    else:
        latent_height = int(height) // self.vae_scale_factor
        latent_width = int(width) // self.vae_scale_factor
        shape = (batch_size, num_channels_latents, latent_height, latent_width)
        initial = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma


@torch.no_grad()
def __call__(
    self,
    image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
):
    r"""
    Generate variations of the input image.

    Args:
        image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
            Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
            [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Weight of classifier-free guidance towards the conditioning image. Guidance is enabled when
            `guidance_scale > 1`.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per input image.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. If not provided, a latents tensor is generated by sampling using the supplied random
            `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. With
            `"latent"`, decoding and the safety check are skipped.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that calls every `callback_steps` steps during inference. The function is called with the
            following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. If not specified, the callback is called at
            every step.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.

    Examples:

    ```py
    from diffusers import StableDiffusionImageVariationPipeline
    from PIL import Image
    from io import BytesIO
    import requests

    pipe = StableDiffusionImageVariationPipeline.from_pretrained(
        "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
    )
    pipe = pipe.to("cuda")

    url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"

    response = requests.get(url)
    image = Image.open(BytesIO(response.content)).convert("RGB")

    out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
    out["images"][0].save("result.jpg")
    ```
    """
    # 0. Fall back to the UNet's native resolution when no size is given
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Validate the inputs
    self.check_inputs(image, height, width, callback_steps)

    # 2. Derive the batch size from the kind of `image`
    if isinstance(image, PIL.Image.Image):
        batch_size = 1
    elif isinstance(image, list):
        batch_size = len(image)
    else:
        batch_size = image.shape[0]
    device = self._execution_device
    # `guidance_scale` follows the guidance weight `w` of equation (2) of the Imagen paper
    # (https://arxiv.org/pdf/2205.11487.pdf); a value of 1 means no classifier free guidance.
    use_guidance = guidance_scale > 1.0

    # 3. Embed the conditioning image
    image_embeddings = self._encode_image(image, device, num_images_per_prompt, use_guidance)

    # 4. Configure the scheduler
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Initial latents
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        self.unet.config.in_channels,
        height,
        width,
        image_embeddings.dtype,
        device,
        generator,
        latents,
    )

    # 6. Scheduler-specific step kwargs
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    total_steps = len(timesteps)
    num_warmup_steps = total_steps - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # with guidance the batch is doubled: unconditional half first, conditional half second
            if use_guidance:
                model_input = torch.cat([latents] * 2)
            else:
                model_input = latents
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(model_input, t, encoder_hidden_states=image_embeddings).sample

            if use_guidance:
                uncond_pred, cond_pred = noise_pred.chunk(2)
                noise_pred = uncond_pred + guidance_scale * (cond_pred - uncond_pred)

            # x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            completed = i + 1
            if completed == total_steps or (completed > num_warmup_steps and completed % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i // getattr(self.scheduler, "order", 1), t, latents)

    self.maybe_free_model_hooks()

    if output_type == "latent":
        image = latents
        has_nsfw_concept = None
    else:
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)

    # flagged images are not denormalized
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not flagged for flagged in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    self.maybe_free_model_hooks()

    if return_dict:
        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
    return (image, has_nsfw_concept)
import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL.Image import torch from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . 
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Extract latents from the output of a VAE encoder call.

    Args:
        encoder_output:
            Object returned by the encoder. It is inspected for a `latent_dist` attribute first and a `latents`
            attribute second.
        generator (`torch.Generator`, *optional*):
            Forwarded to `latent_dist.sample(...)` when `sample_mode` is `"sample"`.
        sample_mode (`str`, defaults to `"sample"`):
            `"sample"` draws from `latent_dist`; `"argmax"` returns `latent_dist.mode()`. Any other value skips
            `latent_dist` entirely.

    Returns:
        The sampled latents, the distribution mode, or the `latents` attribute, in that order of preference.

    Raises:
        AttributeError: If none of the lookups above apply to `encoder_output`.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()

    # Either there is no distribution or the mode is unknown: fall back to precomputed latents.
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure the scheduler via `set_timesteps` and read back the resulting timestep schedule.

    Exactly one of three paths is taken: a custom `timesteps` list, a custom `sigmas` list, or the scheduler's
    default spacing driven by `num_inference_steps`. Extra `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Passed through to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration, and the number of inference steps
        (the length of the schedule for custom inputs, otherwise `num_inference_steps` unchanged).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler pick its own spacing.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom path: make sure the scheduler's signature exposes the keyword we are about to pass.
    supported_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    image_encoder: CLIPVisionModelWithProjection = None,
    requires_safety_checker: bool = True,
):
    """
    Validate the supplied components, patch outdated configs, and register everything on the pipeline.

    Outdated scheduler settings (`steps_offset != 1`, `clip_sample is True`) and an outdated unet `sample_size`
    are rewritten in place on the passed objects after emitting a deprecation notice.

    Args:
        vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, image_encoder:
            Pipeline components; they are handed to `register_modules` unchanged apart from the config patches
            described above.
        requires_safety_checker (`bool`, defaults to `True`):
            Stored in the pipeline config. When `True` and `safety_checker` is `None`, a warning is logged.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    # Legacy scheduler configs: force `steps_offset` to 1.
    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    # Legacy scheduler configs: force `clip_sample` off.
    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message must be an f-string, otherwise `{self.__class__}` is shown verbatim to the user.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Only unets saved by diffusers < 0.9.0 with a sample_size below 64 are patched.
    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    # One factor of 2 per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        A tuple `(prompt_embeds, negative_prompt_embeds)`. `negative_prompt_embeds` is returned exactly as it was
        passed in (possibly `None`) when `do_classifier_free_guidance` is false.

    Raises:
        TypeError: If `negative_prompt` and `prompt` are both given but have different types.
        ValueError: If a list `negative_prompt` does not match the batch size derived from `prompt`.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # Batch size comes from the prompt when one is given, otherwise from the precomputed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize a second time without truncation, only to report which text was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # The attention mask is only forwarded when the text encoder config opts in.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Target dtype preference: text encoder, then unet, then whatever the embeddings already use.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            # No negative prompt: use one empty string per batch entry.
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the unconditional input to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build the per-adapter image embeddings that are handed to the unet.

    Embeddings are either computed from `ip_adapter_image` through `self.encode_image`, or taken from
    `ip_adapter_image_embeds`. With classifier-free guidance, precomputed embeddings are split in two along the
    first dimension (negative half first). Each result is repeated `num_images_per_prompt` times along the first
    dimension, prefixed with its repeated negative counterpart when guidance is on, and moved to `device`.

    Returns:
        A list with one tensor per IP Adapter.

    Raises:
        ValueError: If the number of images differs from the number of image projection layers on the unet.
    """
    positives = []
    negatives = []

    if ip_adapter_image_embeds is not None:
        for precomputed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_half, precomputed = precomputed.chunk(2)
                negatives.append(negative_half)
            positives.append(precomputed)
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        projection_layers = self.unet.encoder_hid_proj.image_projection_layers
        for adapter_image, projection_layer in zip(ip_adapter_image, projection_layers):
            # Hidden states are requested for every projection layer that is not a plain `ImageProjection`.
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            encoded, encoded_negative = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positives.append(encoded[None, :])
            if do_classifier_free_guidance:
                negatives.append(encoded_negative[None, :])

    prepared = []
    for index, positive in enumerate(positives):
        repeated = torch.cat([positive] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            repeated_negative = torch.cat([negatives[index]] * num_images_per_prompt, dim=0)
            repeated = torch.cat([repeated_negative, repeated], dim=0)
        prepared.append(repeated.to(device=device))

    return prepared
def decode_latents(self, latents):
    """
    Deprecated: decode latents with the VAE and return the images as a float32 NumPy array.

    The latents are divided by `self.vae.config.scaling_factor`, decoded, mapped from `[-1, 1]` to `[0, 1]`
    with clamping, and permuted with `(0, 2, 3, 1)` before conversion to NumPy.
    """
    deprecate(
        "decode_latents",
        "1.0.0",
        "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead",
        standard_warn=False,
    )

    unscaled_latents = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(unscaled_latents, return_dict=False)[0]
    normalized = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return normalized.cpu().permute(0, 2, 3, 1).float().numpy()
def check_inputs(
    self,
    prompt,
    strength,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the arguments of a pipeline call and raise `ValueError` on the first inconsistency found.

    Checks, in order: `strength` range, `callback_steps`, callback tensor names, prompt vs. `prompt_embeds`,
    negative prompt vs. `negative_prompt_embeds`, matching embedding shapes, and the IP-Adapter inputs.
    Returns `None` when every check passes.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    if callback_steps is not None:
        is_positive_int = isinstance(callback_steps, int) and callback_steps > 0
        if not is_positive_int:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    if callback_on_step_end_tensor_inputs is not None and any(
        k not in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Exactly one of `prompt` / `prompt_embeds` must be supplied, and a prompt must be `str` or `list`.
    if prompt is None:
        if prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
    elif prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    both_embeds_given = prompt_embeds is not None and negative_prompt_embeds is not None
    if both_embeds_given and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is None:
        return

    if not isinstance(ip_adapter_image_embeds, list):
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
        )
    # Only the first entry is inspected for its rank.
    if ip_adapter_image_embeds[0].ndim not in [3, 4]:
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
        )
) elif isinstance(generator, list): if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " ) init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = retrieve_latents(self.vae.encode(image), generator=generator) init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." ) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Builds a sinusoidal embedding of the guidance scale: `w` is multiplied by 1000 and combined with
    `embedding_dim // 2` exponentially spaced frequencies; the sine half is followed by the cosine half.

    Args:
        w (`torch.Tensor`):
            One-dimensional tensor of guidance scales.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate. Odd values get one trailing zero column.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
    """
    assert len(w.shape) == 1
    num_scales = w.shape[0]
    scaled_w = w * 1000.0

    half_dim = embedding_dim // 2
    log_step = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    frequencies = torch.exp(torch.arange(half_dim, dtype=dtype) * -log_step)
    angles = scaled_w.to(dtype)[:, None] * frequencies[None, :]
    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    if embedding_dim % 2 == 1:
        # zero pad so that odd dimensions are filled completely
        embedding = torch.nn.functional.pad(embedding, (0, 1))

    assert embedding.shape == (num_scales, embedding_dim)
    return embedding
    @property
    def do_classifier_free_guidance(self):
        """Whether guidance is applied: requires `guidance_scale > 1` and a UNet without `time_cond_proj_dim`."""
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def cross_attention_kwargs(self):
        """The `cross_attention_kwargs` stored by the most recent `__call__`."""
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        """Number of timesteps of the most recent denoising loop (set in `__call__`)."""
        return self._num_timesteps

    @property
    def interrupt(self):
        """Flag checked at the top of every denoising iteration; `__call__` resets it to `False` on entry."""
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: int = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
                list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
                a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
                latents as `image`, but if passing latents directly it is not encoded again.
            strength (`float`, *optional*, defaults to 0.8):
                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
                essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`. Passing
                `"latent"` skips VAE decoding and the safety checker and returns the latents.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
                Its `"scale"` entry, when present, is also forwarded to `encode_prompt` as the text-encoder LoRA scale.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during the inference with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            kwargs:
                Only the deprecated `callback` and `callback_steps` entries are read from here.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        # NOTE(review): the bare `Examples:` line above is presumably the placeholder that
        # `replace_example_docstring` fills in - keep it on its own line; confirm against the decorator.

        # Legacy per-step callback API: still honoured, but each use emits a deprecation notice.
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )

        # Callback objects declare the tensors they want themselves; this overrides the argument.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            strength,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        # Stash per-call settings; the `guidance_scale`, `clip_skip`, `cross_attention_kwargs` and `interrupt`
        # properties read them back.
        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            # No usable `prompt`: the batch size comes from the pre-computed embeddings.
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # `image_embeds` is only bound when an IP-Adapter input was given; step 7.1 guards its use with the
        # same condition.
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 4. Preprocess image
        image = self.image_processor.preprocess(image)

        # 5. set timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )
        # `get_timesteps` then replaces both values based on `strength`.
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        # The first remaining timestep, repeated once per generated sample, is the one passed to
        # `prepare_latents`.
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Add image embeds for IP-Adapter
        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
            else None
        )

        # 7.2 Optionally get Guidance Scale Embedding
        # Used only when the UNet declares `time_cond_proj_dim`; in that case
        # `do_classifier_free_guidance` is False and the guidance scale reaches the UNet via `timestep_cond`.
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # `continue` rather than `break`: once interrupted, every remaining timestep is skipped
                # without calling the UNet, the scheduler or the progress bar.
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                # The chunk order matches the `[negative, positive]` concatenation of `prompt_embeds` above.
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    # Requested tensors are looked up by name among this function's local variables.
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # The callback may hand back replacements; anything it omits keeps its current value.
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # Advance the progress bar on the last step, or after warmup on every `scheduler.order`-th
                # step; then call the deprecated callback, if provided.
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    # NOTE(review): `i % callback_steps` assumes `callback_steps` is an int whenever `callback`
                    # is given - confirm `check_inputs` enforces this.
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if not output_type == "latent":
            # Undo the VAE scaling factor before decoding, then run the safety checker on the decoded image.
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        # Images flagged by the safety checker are passed to `postprocess` with `do_denormalize=False`.
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Pull latents out of an encoder output.

    Args:
        encoder_output: Object exposing either a `latent_dist` attribute or a `latents` attribute.
        generator: Forwarded to `latent_dist.sample` when sampling.
        sample_mode: `"sample"` draws from `latent_dist`; `"argmax"` takes `latent_dist.mode()`.

    Returns:
        The sampled latents, the distribution mode, or `encoder_output.latents`, in that order of preference.

    Raises:
        AttributeError: If none of the above applies.
    """
    has_distribution = hasattr(encoder_output, "latent_dist")
    if has_distribution and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_distribution and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")


# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` via its `set_timesteps` method and read the resulting schedule back.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Any extra `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. Only used when
            neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            The device passed on to `set_timesteps`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's timestep schedule and the number of inference steps. For a
        custom schedule the count is the length of that schedule.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom schedule that was passed.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Plain mode: let the scheduler derive its own spacing from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: exactly one of the two schedules is set at this point.
    if timesteps is not None:
        keyword, schedule, label = "timesteps", timesteps, "timestep schedules"
    else:
        keyword, schedule, label = "sigmas", sigmas, "sigmas schedules"

    if keyword not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {label}. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(device=device, **{keyword: schedule}, **kwargs)
    resolved = scheduler.timesteps
    return resolved, len(resolved)
tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "mask", "masked_image_latents"] def __init__( self, vae: Union[AutoencoderKL, AsymmetricAutoencoderKL], text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration" " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" " Hub, it would be very nice if you could open a Pull request for the" " `scheduler/scheduler_config.json` file" ) deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["skip_prk_steps"] = True scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. 
If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 if unet.config.in_channels != 9: logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) prompt_embeds_tuple = self.encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=lora_scale, **kwargs, ) # concatenate for backwards comp prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt def encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 
attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into # the tuple to access the hidden states from the desired layer. prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode an image with `self.image_encoder` and build the matching unconditional encoding.

    Args:
        image:
            Either a `torch.Tensor`, which is used as-is, or any other input, which is first converted to pixel
            values by `self.feature_extractor`.
        device:
            Device the pixel values are moved to before encoding.
        num_images_per_prompt (`int`):
            How many times each encoded row is repeated along the batch dimension (`repeat_interleave`).
        output_hidden_states (`bool`, *optional*):
            When truthy, the second-to-last entry of the encoder's `hidden_states` is returned instead of
            `image_embeds`.

    Returns:
        `tuple`: `(conditional, unconditional)`. With `output_hidden_states`, the unconditional part is the encoding
        of an all-zero image; otherwise it is an all-zero tensor shaped like the conditional embeddings.
    """
    encoder = self.image_encoder
    # The pixel values are cast to whatever dtype the encoder weights use.
    encoder_dtype = next(encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    pixel_values = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = encoder(pixel_values).image_embeds
        cond_embeds = cond_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    def _penultimate_states(pixels):
        # Second-to-last hidden state, repeated once per requested image.
        states = encoder(pixels, output_hidden_states=True).hidden_states[-2]
        return states.repeat_interleave(num_images_per_prompt, dim=0)

    cond_states = _penultimate_states(pixel_values)
    uncond_states = _penultimate_states(torch.zeros_like(pixel_values))
    return cond_states, uncond_states
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker on generated images.

    Args:
        image:
            Generated images. Tensors are converted with `self.image_processor.postprocess(..., output_type="pil")`,
            anything else with `self.image_processor.numpy_to_pil`, before feature extraction.
        device:
            Device the extracted features are moved to.
        dtype:
            Dtype the `pixel_values` are cast to before they are handed to the checker.

    Returns:
        `tuple`: `(image, has_nsfw_concept)`. When `self.safety_checker` is `None`, the input is returned untouched
        together with `None`; otherwise both values come from the checker.
    """
    if self.safety_checker is None:
        return image, None

    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    checked_image, nsfw_flags = self.safety_checker(
        images=image, clip_input=clip_features.pixel_values.to(dtype)
    )
    return checked_image, nsfw_flags
def check_inputs(
    self,
    prompt,
    image,
    mask_image,
    height,
    width,
    strength,
    callback_steps,
    output_type,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    padding_mask_crop=None,
):
    """
    Validate the arguments of a pipeline call and raise on the first inconsistency found.

    The checks run in this order:

    1. `strength` lies in `[0, 1]`.
    2. `height` and `width` are multiples of `self.vae_scale_factor`.
    3. `callback_steps`, when given, is a positive `int`.
    4. Every entry of `callback_on_step_end_tensor_inputs` is listed in `self._callback_tensor_inputs`.
    5. Exactly one of `prompt` / `prompt_embeds` is given, and `prompt` is a `str` or `list`.
    6. `negative_prompt` and `negative_prompt_embeds` are not both given.
    7. `prompt_embeds` and `negative_prompt_embeds`, when both given, have the same shape.
    8. With `padding_mask_crop`, `image` and `mask_image` are PIL images and `output_type` is `"pil"`.
    9. `ip_adapter_image` and `ip_adapter_image_embeds` are not both given.
    10. `ip_adapter_image_embeds`, when given, is a non-empty `list` whose first element is a 3D or 4D tensor.

    Raises:
        ValueError: If any of the checks above fails.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
        # Report the factor that is actually enforced instead of a hard-coded 8.
        raise ValueError(
            f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
        )

    if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None:
        unknown_inputs = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown_inputs:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown_inputs}"
            )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if padding_mask_crop is not None:
        if not isinstance(image, PIL.Image.Image):
            raise ValueError(
                f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
            )
        if not isinstance(mask_image, PIL.Image.Image):
            raise ValueError(
                f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                f" {type(mask_image)}."
            )
        if output_type != "pil":
            raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        if not ip_adapter_image_embeds:
            # An empty list would otherwise surface as a bare IndexError from the `[0]` lookup below.
            raise ValueError("`ip_adapter_image_embeds` has to be a non-empty list of 3D or 4D tensors.")
        if ip_adapter_image_embeds[0].ndim not in [3, 4]:
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
            )
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """
    Encode `image` with `self.vae` and scale the result by `self.vae.config.scaling_factor`.

    Args:
        image (`torch.Tensor`):
            Image batch handed to `self.vae.encode`.
        generator (`torch.Generator` or `List[torch.Generator]`):
            Forwarded to `retrieve_latents`. When a list is given, every batch element is encoded on its own with
            the generator at the same index and the results are concatenated along dim 0.

    Returns:
        `torch.Tensor`: The scaled image latents.
    """
    if not isinstance(generator, list):
        encoded = retrieve_latents(self.vae.encode(image), generator=generator)
    else:
        per_sample = []
        for idx in range(image.shape[0]):
            # Slice with a range so the batch dimension is kept for the encoder.
            single = image[idx : idx + 1]
            per_sample.append(retrieve_latents(self.vae.encode(single), generator=generator[idx]))
        encoded = torch.cat(per_sample, dim=0)

    return self.vae.config.scaling_factor * encoded
masked_image.shape[1] == 4: masked_image_latents = masked_image else: masked_image_latents = self._encode_vae_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: if not batch_size % mask.shape[0] == 0: raise ValueError( "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" " of masks that you pass is divisible by the total requested batch size." ) mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) if masked_image_latents.shape[0] < batch_size: if not batch_size % masked_image_latents.shape[0] == 0: raise ValueError( "The passed images and the required batch size don't match. Images are supposed to be duplicated" f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." " Make sure the number of images that you pass is divisible by the total requested batch size." 
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Build sinusoidal embeddings for a batch of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance-scale values to embed, one entry per sample. The embedding is subsequently used
            to enrich timestep embeddings.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate. Odd values are zero-padded by one column.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, located on the device of `w`.

    Raises:
        ValueError: If `w` is not one-dimensional.
    """
    # Explicit check instead of `assert`, so the validation survives `python -O`.
    if w.ndim != 1:
        raise ValueError(f"`w` has to be a 1-D tensor but has shape {tuple(w.shape)}.")

    w = w * 1000.0

    half_dim = embedding_dim // 2
    # Log-spaced frequencies running from 1 down to 1/10000.
    # NOTE: `embedding_dim` of 2 or 3 makes `half_dim - 1` zero, so this division is by zero.
    log_step = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # Build the frequency table on the device of `w`, so a non-CPU `w` does not hit a device mismatch below.
    freqs = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -log_step)

    angles = w.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))

    # Internal invariant, not input validation.
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
= "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: int = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but if passing latents directly it is not encoded again. mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`, or `(H, W)`. 
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. padding_mask_crop (`int`, *optional*, defaults to `None`): The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large and contain information irrelevant for inpainting, such as background. strength (`float`, *optional*, defaults to 1.0): Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a starting point and more noise is added the higher the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter is modulated by `strength`. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. 
sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 
callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: ```py >>> import PIL >>> import requests >>> import torch >>> from io import BytesIO >>> from diffusers import StableDiffusionInpaintPipeline >>> def download_image(url): ... response = requests.get(url) ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" >>> init_image = download_image(img_url).resize((512, 512)) >>> mask_image = download_image(mask_url).resize((512, 512)) >>> pipe = StableDiffusionInpaintPipeline.from_pretrained( ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. 
""" callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs self.check_inputs( prompt, image, mask_image, height, width, strength, callback_steps, output_type, negative_prompt, prompt_embeds, negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, padding_mask_crop, ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 4. set timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps=num_inference_steps, strength=strength, device=device ) # check that number of inference steps is not < 1 - as this doesn't make sense if num_inference_steps < 1: raise ValueError( f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." ) # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise is_strength_max = strength == 1.0 # 5. Preprocess mask and image if padding_mask_crop is not None: crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) resize_mode = "fill" else: crops_coords = None resize_mode = "default" original_image = image init_image = self.image_processor.preprocess( image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode ) init_image = init_image.to(dtype=torch.float32) # 6. 
Prepare latent variables num_channels_latents = self.vae.config.latent_channels num_channels_unet = self.unet.config.in_channels return_image_latents = num_channels_unet == 4 latents_outputs = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, image=init_image, timestep=latent_timestep, is_strength_max=is_strength_max, return_noise=True, return_image_latents=return_image_latents, ) if return_image_latents: latents, noise, image_latents = latents_outputs else: latents, noise = latents_outputs # 7. Prepare mask latent variables mask_condition = self.mask_processor.preprocess( mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords ) if masked_image_latents is None: masked_image = init_image * (mask_condition < 0.5) else: masked_image = masked_image_latents mask, masked_image_latents = self.prepare_mask_latents( mask_condition, masked_image, batch_size * num_images_per_prompt, height, width, prompt_embeds.dtype, device, generator, self.do_classifier_free_guidance, ) # 8. Check that sizes of mask, masked image and latents match if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." 
) elif num_channels_unet != 4: raise ValueError( f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." ) # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9.1 Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None ) # 9.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) # 10. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) if num_channels_unet == 9: latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * 
(noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if num_channels_unet == 4: init_latents_proper = image_latents if self.do_classifier_free_guidance: init_mask, _ = mask.chunk(2) else: init_mask = mask if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] init_latents_proper = self.scheduler.add_noise( init_latents_proper, noise, torch.tensor([noise_timestep]) ) latents = (1 - init_mask) * init_latents_proper + init_mask * latents if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) mask = callback_outputs.pop("mask", mask) masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": condition_kwargs = {} if isinstance(self.vae, AsymmetricAutoencoderKL): init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) init_image_condition = init_image.clone() init_image = self._encode_vae_image(init_image, generator=generator) mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype) condition_kwargs = {"image": init_image_condition, "mask": mask_condition} image = self.vae.decode( latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs 
)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if padding_mask_crop is not None: image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py ================================================ # Copyright 2024 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL.Image import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) 
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Pull latents out of an encoder output object.

    Args:
        encoder_output: Object exposing either a `latent_dist` attribute (with `sample(generator)` and `mode()`
            methods) or a `latents` attribute.
        generator: Passed to `latent_dist.sample` when `sample_mode` is `"sample"`.
        sample_mode: `"sample"` draws from `latent_dist`; `"argmax"` returns `latent_dist.mode()`. Any other value
            skips `latent_dist` entirely.

    Returns:
        The sampled latents, the distribution mode, or the stored `latents` attribute.

    Raises:
        AttributeError: If none of the access paths above applies.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()

    # Reached when there is no distribution, or the sample mode is not one of the two recognised values.
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
""" model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, image_encoder: Optional[CLIPVisionModelWithProjection] = None, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: PipelineImageInput = None,
    num_inference_steps: int = 100,
    guidance_scale: float = 7.5,
    image_guidance_scale: float = 1.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    # NOTE: list default kept for interface compatibility; it is only read, never mutated, in this method.
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
            image latents as `image`, but if passing latents directly it is not encoded again.
        num_inference_steps (`int`, *optional*, defaults to 100):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        image_guidance_scale (`float`, *optional*, defaults to 1.5):
            Push the generated image towards the initial `image`. Image guidance scale is enabled by setting
            `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
            linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
            value of at least `1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. When given, they are used instead of encoding
            `ip_adapter_image`. With classifier-free guidance enabled, each tensor is split into three chunks
            along its first dimension by `prepare_ip_adapter_image_embeds`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        **kwargs:
            Only the deprecated `callback` and `callback_steps` entries are read; both trigger a deprecation
            notice when supplied.

    Examples:

    ```py
    >>> import PIL
    >>> import requests
    >>> import torch
    >>> from io import BytesIO

    >>> from diffusers import StableDiffusionInstructPix2PixPipeline


    >>> def download_image(url):
    ...     response = requests.get(url)
    ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")


    >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"

    >>> image = download_image(img_url).resize((512, 512))

    >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    ...     "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
    ... )
    >>> pipe = pipe.to("cuda")

    >>> prompt = "make the mountains snowy"
    >>> image = pipe(prompt=prompt, image=image).images[0]
    ```

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.

    Raises:
        ValueError: If `image` is `None`, or if the latent and image-latent channel counts do not add up to
            `unet.config.in_channels`.
    """

    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )

    # Callback objects declare the tensors they need themselves.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 0. Check inputs
    self.check_inputs(
        prompt,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs,
    )
    self._guidance_scale = guidance_scale
    self._image_guidance_scale = image_guidance_scale

    device = self._execution_device

    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    # 1. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # 2. Encode input prompt
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    # IP-Adapter conditioning is active when either raw images or precomputed embeddings are supplied.
    # The same flag gates both the preparation here and the UNet kwargs in step 8.1, so embeddings that
    # were prepared can never be silently dropped.
    use_ip_adapter = ip_adapter_image is not None or ip_adapter_image_embeds is not None
    if use_ip_adapter:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 3. Preprocess image
    image = self.image_processor.preprocess(image)

    # 4. set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare Image latents
    image_latents = self.prepare_image_latents(
        image,
        batch_size,
        num_images_per_prompt,
        prompt_embeds.dtype,
        device,
        self.do_classifier_free_guidance,
    )

    # Output resolution is derived from the encoded image, scaled back up by the VAE factor.
    height, width = image_latents.shape[-2:]
    height = height * self.vae_scale_factor
    width = width * self.vae_scale_factor

    # 6. Prepare latent variables
    num_channels_latents = self.vae.config.latent_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 7. Check that shapes of latents and image match the UNet channels
    num_channels_image = image_latents.shape[1]
    if num_channels_latents + num_channels_image != self.unet.config.in_channels:
        raise ValueError(
            f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
            f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
            f" `num_channels_image`: {num_channels_image} "
            f" = {num_channels_latents+num_channels_image}. Please verify the config of"
            " `pipeline.unet` or your `image` input."
        )

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 8.1 Add image embeds for IP-Adapter
    added_cond_kwargs = {"image_embeds": image_embeds} if use_ip_adapter else None

    # 9. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Expand the latents if we are doing classifier free guidance.
            # The latents are expanded 3 times because for pix2pix the guidance
            # is applied for both the text and the input image.
            latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents

            # concat latents, image_latents in the channel dimension
            scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)

            # predict the noise residual
            noise_pred = self.unet(
                scaled_latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=added_cond_kwargs,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
                noise_pred = (
                    noise_pred_uncond
                    + self.guidance_scale * (noise_pred_text - noise_pred_image)
                    + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
                )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                # The requested tensors are looked up by local variable name, so the locals
                # `latents`, `prompt_embeds` and `image_latents` must keep these exact names.
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                image_latents = callback_outputs.pop("image_latents", image_latents)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Flagged images are not denormalized.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype else: prompt_embeds_dtype = self.unet.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] 
if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build the list of IP-Adapter image embeddings handed to the UNet.

    Args:
        ip_adapter_image: A single image or a list of images, encoded only when `ip_adapter_image_embeds` is
            `None`. A non-list value is wrapped in a one-element list.
        ip_adapter_image_embeds: Optional iterable of precomputed embedding tensors; when given it takes
            precedence over `ip_adapter_image`.
        device: Device forwarded to `encode_image`, and onto which the guidance-concatenated embeds are moved.
        num_images_per_prompt: Number of copies to make along the first dimension.
        do_classifier_free_guidance: When true, each output is ordered `[positive, negative, negative]` along
            the first dimension.

    Returns:
        A list with one tensor per input image or per precomputed embedding.

    Raises:
        ValueError: If images are encoded and their count differs from the number of image projection layers
            on `self.unet.encoder_hid_proj`.
    """

    def _tile(tensor):
        # Repeat along dim 0 only; every trailing dimension keeps a repeat factor of 1.
        trailing_ones = [1] * len(tensor.shape[1:])
        return tensor.repeat(num_images_per_prompt, *trailing_ones)

    # Fast path: the caller supplied precomputed embeddings.
    if ip_adapter_image_embeds is not None:
        prepared = []
        for embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                # Precomputed embeds arrive as three chunks; the last one is used as the negative.
                positive, _, negative = embeds.chunk(3)
                positive = _tile(positive)
                negative = _tile(negative)
                prepared.append(torch.cat([positive, negative, negative]))
            else:
                prepared.append(_tile(embeds))
        return prepared

    # Otherwise encode the raw images, one per IP-Adapter projection layer.
    ip_adapter_image = ip_adapter_image if isinstance(ip_adapter_image, list) else [ip_adapter_image]
    projection_layers = self.unet.encoder_hid_proj.image_projection_layers

    if len(ip_adapter_image) != len(projection_layers):
        raise ValueError(
            f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
        )

    prepared = []
    for adapter_image, projection_layer in zip(ip_adapter_image, projection_layers):
        # Hidden states are requested for any projection layer that is not a plain `ImageProjection`.
        wants_hidden_states = not isinstance(projection_layer, ImageProjection)
        positive, negative = self.encode_image(adapter_image, device, 1, wants_hidden_states)

        positive = torch.stack([positive] * num_images_per_prompt, dim=0)
        negative = torch.stack([negative] * num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            positive = torch.cat([positive, negative, negative])
            positive = positive.to(device)

        prepared.append(positive)

    return prepared
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Scheduler `step` signatures differ, so the signature is inspected and only supported names are forwarded.
    `eta` corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1];
    schedulers whose `step` does not take `eta` simply never receive it. The same applies to `generator`.

    Returns:
        A dict containing `"eta"` and/or `"generator"` (in that order) for the names the scheduler accepts.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
def check_inputs(
    self,
    prompt,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the user-facing call arguments and raise `ValueError` on the first inconsistency.

    Checks, in order: `callback_steps` is `None` or a positive int; every name in
    `callback_on_step_end_tensor_inputs` is listed in `self._callback_tensor_inputs`; exactly one of
    `prompt` / `prompt_embeds` is given and `prompt` is a `str` or `list`; `negative_prompt` and
    `negative_prompt_embeds` are not both given; directly passed prompt embeddings share one shape;
    `ip_adapter_image` and `ip_adapter_image_embeds` are not both given; `ip_adapter_image_embeds` is a
    list whose first element is 3D or 4D.

    Returns:
        `None` when every check passes.
    """
    callback_steps_valid = callback_steps is None or (isinstance(callback_steps, int) and callback_steps > 0)
    if not callback_steps_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None and any(
        name not in self._callback_tensor_inputs for name in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Exactly one of `prompt` / `prompt_embeds` has to be supplied.
    if prompt is None:
        if prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
    elif prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    both_embeds_given = prompt_embeds is not None and negative_prompt_embeds is not None
    if both_embeds_given and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is None:
        return

    if not isinstance(ip_adapter_image_embeds, list):
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
        )
    # Only the first element's rank is inspected.
    if ip_adapter_image_embeds[0].ndim not in [3, 4]:
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
        )
def prepare_image_latents(
    self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
):
    """
    Turn the conditioning image into latents sized for the effective batch.

    The effective batch is `batch_size * num_images_per_prompt`. An input whose second dimension is 4 is used
    as latents directly; anything else is passed through `self.vae.encode` and `retrieve_latents` with
    `sample_mode="argmax"`. When the effective batch is a larger multiple of the latent batch, the latents
    are tiled (after a deprecation notice). With classifier-free guidance the result is
    `[latents, latents, zeros_like(latents)]` concatenated along dim 0.

    Raises:
        ValueError: If `image` is not a tensor, PIL image or list, or if the effective batch is larger than
            the latent batch without being a multiple of it.
    """
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)
    batch_size = batch_size * num_images_per_prompt

    image_latents = (
        image if image.shape[1] == 4 else retrieve_latents(self.vae.encode(image), sample_mode="argmax")
    )

    if batch_size > image_latents.shape[0]:
        if batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        # expand image_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        copies = batch_size // image_latents.shape[0]
        image_latents = torch.cat([image_latents] * copies, dim=0)
    else:
        image_latents = torch.cat([image_latents], dim=0)

    if do_classifier_free_guidance:
        zero_latents = torch.zeros_like(image_latents)
        image_latents = torch.cat([image_latents, image_latents, zero_latents], dim=0)

    return image_latents
def preprocess(image):
    """
    Deprecated helper converting PIL images or tensors into one batched tensor.

    Always emits a `FutureWarning`. A tensor is returned untouched. A PIL image (or list of them) is resized
    to the first image's size rounded down to a multiple of 64, scaled to `[-1, 1]` and returned as a
    channels-first `float32` tensor. A list of tensors is concatenated along dim 0. Any other list is
    returned as is.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )
    if isinstance(image, torch.Tensor):
        return image

    images = [image] if isinstance(image, PIL.Image.Image) else image
    first = images[0]

    if isinstance(first, PIL.Image.Image):
        width, height = first.size
        # resize to integer multiple of 64
        width -= width % 64
        height -= height % 64
        frames = [np.array(frame.resize((width, height)))[None, :] for frame in images]
        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)
        batch = 2.0 * batch - 1.0
        return torch.from_numpy(batch)

    if isinstance(first, torch.Tensor):
        return torch.cat(images, dim=0)

    return images
def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt):
    r"""
    Encode the prompt into the text encoder's last hidden states and pooled output.

    Args:
        prompt (`str` or `List[str]`):
            prompt to be encoded
        device: (`torch.device`):
            device the token ids are moved to before the text encoder runs
        do_classifier_free_guidance (`bool`):
            whether to also encode an unconditional prompt and prepend it to the batch
        negative_prompt (`str` or `List[str]`):
            The prompt or prompts not to guide the image generation. Only read when
            `do_classifier_free_guidance` is true; `None` falls back to empty strings.

    Returns:
        A tuple `(text_embeddings, text_pooler_out)`. With classifier free guidance both tensors hold the
        unconditional entries first, followed by the conditional ones.
    """
    batch_size = len(prompt) if isinstance(prompt, list) else 1

    tokenizer_options = {"truncation": True, "return_length": True, "return_tensors": "pt"}
    text_inputs = self.tokenizer(
        prompt, padding="max_length", max_length=self.tokenizer.model_max_length, **tokenizer_options
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    was_truncated = untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
        text_input_ids, untruncated_ids
    )
    if was_truncated:
        removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {self.tokenizer.model_max_length} tokens: {removed_text}"
        )

    cond_out = self.text_encoder(text_input_ids.to(device), output_hidden_states=True)
    text_embeddings = cond_out.hidden_states[-1]
    text_pooler_out = cond_out.pooler_output

    if not do_classifier_free_guidance:
        return text_embeddings, text_pooler_out

    # get unconditional embeddings for classifier free guidance
    if negative_prompt is None:
        uncond_tokens = [""] * batch_size
    elif type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    elif isinstance(negative_prompt, str):
        uncond_tokens = [negative_prompt]
    elif batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )
    else:
        uncond_tokens = negative_prompt

    # Pad the unconditional tokens to the same length as the conditional ones.
    uncond_input = self.tokenizer(
        uncond_tokens, padding="max_length", max_length=text_input_ids.shape[-1], **tokenizer_options
    )
    uncond_out = self.text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True)

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    text_embeddings = torch.cat([uncond_out.hidden_states[-1], text_embeddings])
    text_pooler_out = torch.cat([uncond_out.pooler_output, text_pooler_out])

    return text_embeddings, text_pooler_out
instead" deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image def check_inputs(self, prompt, image, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if ( not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list) ): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" ) # verify batch size of prompt and image are same if image is a list or tensor if isinstance(image, (list, torch.Tensor)): if isinstance(prompt, str): batch_size = 1 else: batch_size = len(prompt) if isinstance(image, list): image_batch_size = len(image) else: image_batch_size = image.shape[0] if image.ndim == 4 else 1 if batch_size != image_batch_size: raise ValueError( f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." " Please make sure that passed `prompt` matches the batch size of `image`." ) if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: PipelineImageInput = None,
    num_inference_steps: int = 75,
    guidance_scale: float = 9.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide image upscaling.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image` or tensor representing an image batch to be upscaled. If it's a tensor, it can be either a
            latent output from a Stable Diffusion model or an image tensor in the range `[-1, 1]`. After
            preprocessing, an input with 3 channels (`image.shape[1] == 3`) is encoded with this pipeline's
            `vae` encoder; any other input is used as latents directly.
        num_inference_steps (`int`, *optional*, defaults to 75):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 9.0):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            When `guidance_scale == 0`, the prompt is replaced by empty strings before encoding.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, empty
            strings are used. Ignored when not using guidance (`guidance_scale <= 1`).
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor is generated by sampling using the supplied random `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. With
            `"latent"` the VAE decoding step is skipped.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that calls every `callback_steps` steps during inference. The function is called with the
            following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. Must be a positive integer.

    Examples:
    ```py
    >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
    >>> import torch


    >>> pipeline = StableDiffusionPipeline.from_pretrained(
    ...     "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
    ... )
    >>> pipeline.to("cuda")

    >>> model_id = "stabilityai/sd-x2-latent-upscaler"
    >>> upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    >>> upscaler.to("cuda")

    >>> prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic"
    >>> generator = torch.manual_seed(33)

    >>> low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images

    >>> with torch.no_grad():
    ...     image = pipeline.decode_latents(low_res_latents)
    >>> image = pipeline.numpy_to_pil(image)[0]

    >>> image.save("../images/a1.png")

    >>> upscaled_image = upscaler(
    ...     prompt=prompt,
    ...     image=low_res_latents,
    ...     num_inference_steps=20,
    ...     guidance_scale=0,
    ...     generator=generator,
    ... ).images[0]

    >>> upscaled_image.save("../images/a2.png")
    ```

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
            returned where the first element is the post-processed output of `self.image_processor`.
    """

    # 1. Check inputs
    self.check_inputs(prompt, image, callback_steps)

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # A guidance scale of exactly 0 discards the user prompt in favor of empty strings.
    if guidance_scale == 0:
        prompt = [""] * batch_size

    # 3. Encode input prompt
    text_embeddings, text_pooler_out = self._encode_prompt(
        prompt, device, do_classifier_free_guidance, negative_prompt
    )

    # 4. Preprocess image
    image = self.image_processor.preprocess(image)
    image = image.to(dtype=text_embeddings.dtype, device=device)
    if image.shape[1] == 3:
        # encode image if not in latent-space yet
        image = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor

    # 5. set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # The conditioning image is duplicated so it lines up with the [uncond, cond] batch used for guidance.
    batch_multiplier = 2 if do_classifier_free_guidance else 1
    image = image[None, :] if image.ndim == 3 else image
    image = torch.cat([image] * batch_multiplier)

    # 6. Add noise to image (set to be 0):
    # (see below notes from the author):
    # "This step theoretically can make the model work better on out-of-distribution inputs, but mostly just seems to make it match the input less, so it's turned off by default."
    noise_level = torch.tensor([0.0], dtype=torch.float32, device=device)
    noise_level = torch.cat([noise_level] * image.shape[0])
    inv_noise_level = (noise_level**2 + 1) ** (-0.5)

    # Nearest-neighbour 2x upsampling of the conditioning image, scaled per sample by `inv_noise_level`.
    image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None]
    image_cond = image_cond.to(text_embeddings.dtype)

    # Constant 128-wide vector (64 ones followed by 64 zeros) per batch entry, prepended to the pooled text output.
    noise_level_embed = torch.cat(
        [
            torch.ones(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
            torch.zeros(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
        ],
        dim=1,
    )

    timestep_condition = torch.cat([noise_level_embed, text_pooler_out], dim=1)

    # 7. Prepare latent variables
    height, width = image.shape[2:]
    num_channels_latents = self.vae.config.latent_channels
    latents = self.prepare_latents(
        batch_size,
        num_channels_latents,
        height * 2,  # 2x upscale
        width * 2,
        text_embeddings.dtype,
        device,
        generator,
        latents,
    )

    # 8. Check that sizes of image and latents match
    num_channels_image = image.shape[1]
    if num_channels_latents + num_channels_image != self.unet.config.in_channels:
        raise ValueError(
            f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
            f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
            f" `num_channels_image`: {num_channels_image} "
            f" = {num_channels_latents+num_channels_image}. Please verify the config of"
            " `pipeline.unet` or your `image` input."
        )

    # 9. Denoising loop
    num_warmup_steps = 0

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            sigma = self.scheduler.sigmas[i]
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # The upsampled conditioning image is appended along the channel dimension.
            scaled_model_input = torch.cat([scaled_model_input, image_cond], dim=1)
            # preconditioning parameter based on Karras et al. (2022) (table 1)
            timestep = torch.log(sigma) * 0.25

            noise_pred = self.unet(
                scaled_model_input,
                timestep,
                encoder_hidden_states=text_embeddings,
                timestep_cond=timestep_condition,
            ).sample

            # in original repo, the output contains a variance channel that's not used
            noise_pred = noise_pred[:, :-1]

            # apply preconditioning, based on table 1 in Karras et al. (2022)
            inv_sigma = 1 / (sigma**2 + 1)
            noise_pred = inv_sigma * latent_model_input + self.scheduler.scale_model_input(sigma, t) * noise_pred

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    # Decode unless the caller asked for raw latents.
    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    else:
        image = latents

    image = self.image_processor.postprocess(image, output_type=output_type)

    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
    AttnProcessor2_0,
    XFormersAttnProcessor,
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionPipelineOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def preprocess(image):
    """Convert an image input to a `torch.Tensor` (deprecated helper).

    Always emits a `FutureWarning` pointing at `VaeImageProcessor.preprocess`.

    - A `torch.Tensor` is returned unchanged.
    - A `PIL.Image.Image` (or a list of them) is resized so that width and height (taken from the first image)
      are integer multiples of 64, scaled to `[-1, 1]` and returned in `(batch, channels, height, width)` layout.
    - A list of `torch.Tensor` is concatenated along dimension 0.
    - Any other input falls through and is returned as-is.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = (x - x % 64 for x in (w, h))  # resize to integer multiple of 64

        image = [np.array(i.resize((w, h)))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image


class StableDiffusionUpscalePipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,
    StableDiffusionLoraLoaderMixin,
    FromSingleFileMixin,
):
    r"""
    Pipeline for text-guided image super-resolution using Stable Diffusion 2.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        low_res_scheduler ([`SchedulerMixin`]):
            A scheduler used to add initial noise to the low resolution conditioning image. It must be an instance of
            [`DDPMScheduler`].
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker (*optional*):
            Callable invoked on the decoded images as `safety_checker(images=..., clip_input=...)`; it must return
            `(images, nsfw_detected, watermark_detected)`. When `None`, no safety check is run.
        feature_extractor ([`~transformers.CLIPImageProcessor`], *optional*):
            Produces the `clip_input` pixel values for `safety_checker`. Only used when `safety_checker` is set.
        watermarker (*optional*):
            Object exposing `apply_watermark(images)`; applied to the final images when `output_type == "pil"`.
        max_noise_level (`int`, *optional*, defaults to 350):
            Upper bound accepted for the `noise_level` call argument; stored in the pipeline config.
    """

    model_cpu_offload_seq = "text_encoder->unet->vae"
    _optional_components = ["watermarker", "safety_checker", "feature_extractor"]
    _exclude_from_cpu_offload = ["safety_checker"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        low_res_scheduler: DDPMScheduler,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: Optional[Any] = None,
        feature_extractor: Optional[CLIPImageProcessor] = None,
        watermarker: Optional[Any] = None,
        max_noise_level: int = 350,
    ):
        super().__init__()

        if hasattr(vae, "config"):
            # The x4 upscaler VAE is expected to use `scaling_factor == 0.08333`. If the config lacks the attribute
            # or holds another value, emit a deprecation warning and overwrite it. `getattr` with a default is used
            # so that building the message cannot itself fail when the attribute is missing.
            current_scaling_factor = getattr(vae.config, "scaling_factor", None)
            if current_scaling_factor != 0.08333:
                deprecation_message = (
                    "The configuration file of the vae does not contain `scaling_factor` or it is set to"
                    f" {current_scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned"
                    " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to"
                    " 0.08333 Please make sure to update the config accordingly, as not doing so might lead to"
                    " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging"
                    " Face Hub, it would be very nice if you could open a Pull Request for the `vae/config.json` file"
                )
                deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False)
                vae.register_to_config(scaling_factor=0.08333)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            safety_checker=safety_checker,
            watermarker=watermarker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
        self.register_to_config(max_noise_level=max_noise_level)

    def run_safety_checker(self, image, device, dtype):
        """Run the optional safety checker on decoded images.

        Args:
            image: Decoded image batch as produced by `vae.decode`.
            device: Device the feature-extractor output is moved to.
            dtype: Dtype the CLIP pixel values are cast to before calling the safety checker.

        Returns:
            `(image, nsfw_detected, watermark_detected)`. The last two are `None` when no safety checker is
            registered; otherwise all three values come from the safety checker.
        """
        if self.safety_checker is not None:
            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, nsfw_detected, watermark_detected = self.safety_checker(
                images=image,
                clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
            )
        else:
            nsfw_detected = None
            watermark_detected = None

            # Denoising is finished at this point, so the UNet can be offloaded if an offload hook is installed.
            if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
                self.unet_offload_hook.offload()

        return image, nsfw_detected, watermark_detected

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        **kwargs,
    ):
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )

        # concatenate for backwards comp
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def check_inputs(
        self,
        prompt,
        image,
        noise_level,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate the arguments of `__call__`.

        Raises:
            ValueError: If `callback_steps` is not a positive `int`; if `prompt` and `prompt_embeds` are both given
                or both missing; if `prompt` is neither `str` nor `list`; if `negative_prompt` and
                `negative_prompt_embeds` are both given; if the two embedding tensors differ in shape; if `image`
                is not a tensor, array, PIL image or list; if the batch sizes of `prompt` and `image` disagree; or
                if `noise_level` exceeds `self.config.max_noise_level`.
        """
        # `isinstance(None, int)` is False, so this single test also rejects `None`.
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if not isinstance(image, (torch.Tensor, PIL.Image.Image, np.ndarray, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
            )

        # verify batch size of prompt and image are same if image is a list or tensor or numpy array
        if isinstance(image, (list, np.ndarray, torch.Tensor)):
            if prompt is not None and isinstance(prompt, str):
                batch_size = 1
            elif prompt is not None and isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                batch_size = prompt_embeds.shape[0]

            if isinstance(image, list):
                image_batch_size = len(image)
            else:
                image_batch_size = image.shape[0]
            if batch_size != image_batch_size:
                raise ValueError(
                    f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
                    " Please make sure that passed `prompt` matches the batch size of `image`."
                )

        # check noise level
        if noise_level > self.config.max_noise_level:
            raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """Create (or validate) the initial latents and scale them for the scheduler.

        When `latents` is `None`, noise of shape `(batch_size, num_channels_latents, height, width)` is sampled
        with `randn_tensor`. Otherwise the given tensor must have exactly that shape and is moved to `device`.
        The result is multiplied by `self.scheduler.init_noise_sigma`.

        Raises:
            ValueError: If user-supplied `latents` do not have the expected shape.
        """
        shape = (batch_size, num_channels_latents, height, width)
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def upcast_vae(self):
        """Cast the VAE to float32, keeping some decoder parts in the original dtype when that is safe."""
        dtype = self.vae.dtype
        self.vae.to(dtype=torch.float32)
        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
            ),
        )
        # if xformers or torch_2_0 is used attention block does not need
        # to be in float32 which can save lots of memory
        if use_torch_2_0_or_xformers:
            self.vae.post_quant_conv.to(dtype)
            self.vae.decoder.conv_in.to(dtype)
            self.vae.decoder.mid_block.to(dtype)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        num_inference_steps: int = 75,
        guidance_scale: float = 9.0,
        noise_level: int = 20,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image` or tensor representing an image batch to be upscaled.
            num_inference_steps (`int`, *optional*, defaults to 75):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 9.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            noise_level (`int`, *optional*, defaults to 20):
                Timestep handed to `low_res_scheduler.add_noise` to noise the low-resolution conditioning image; the
                same value is passed to the UNet as `class_labels`. Must not exceed `self.config.max_noise_level`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`. With `"latent"`
                the VAE decode and the safety checker are skipped and the latents are post-processed directly.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Examples:
        ```py
        >>> import requests
        >>> from PIL import Image
        >>> from io import BytesIO
        >>> from diffusers import StableDiffusionUpscalePipeline
        >>> import torch

        >>> # load model and scheduler
        >>> model_id = "stabilityai/stable-diffusion-x4-upscaler"
        >>> pipeline = StableDiffusionUpscalePipeline.from_pretrained(
        ...     model_id, variant="fp16", torch_dtype=torch.float16
        ... )
        >>> pipeline = pipeline.to("cuda")

        >>> # let's download an image
        >>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
        >>> response = requests.get(url)
        >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
        >>> low_res_img = low_res_img.resize((128, 128))
        >>> prompt = "a white cat"

        >>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
        >>> upscaled_image.save("upsampled_cat.png")
        ```

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content (`None` when no safety checker ran).

        Raises:
            ValueError: If `image` is `None`, if any input fails `check_inputs`, or if the channel count of latents
                plus conditioning image does not match `unet.config.in_channels`.
        """

        # 1. Check inputs
        # The `None` test must come first: `check_inputs` would otherwise reject `None` with a generic
        # type error and this more specific message could never be reached.
        if image is None:
            raise ValueError("`image` input cannot be undefined.")

        self.check_inputs(
            prompt,
            image,
            noise_level,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Preprocess image
        image = self.image_processor.preprocess(image)
        image = image.to(dtype=prompt_embeds.dtype, device=device)

        # 5. Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 6. Add noise to image
        noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
        noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
        image = self.low_res_scheduler.add_noise(image, noise, noise_level)

        batch_multiplier = 2 if do_classifier_free_guidance else 1
        # NOTE(review): this tiles the whole image batch (i0, i1, i0, i1, ...) whereas `encode_prompt` repeats
        # each prompt consecutively (p0, p0, p1, p1, ...). With `batch_size > 1` and `num_images_per_prompt > 1`
        # the image/prompt pairing looks misaligned - confirm before relying on that combination.
        image = torch.cat([image] * batch_multiplier * num_images_per_prompt)
        noise_level = torch.cat([noise_level] * image.shape[0])

        # 7. Prepare latent variables
        # The latents share the spatial size of the (low resolution) conditioning image.
        height, width = image.shape[2:]
        num_channels_latents = self.vae.config.latent_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 8. Check that sizes of image and latents match
        num_channels_image = image.shape[1]
        if num_channels_latents + num_channels_image != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_image`: {num_channels_image} "
                f" = {num_channels_latents+num_channels_image}. Please verify the config of"
                " `pipeline.unet` or your `image` input."
            )

        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 10. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                # concat the scaled latents with the noised low resolution image in the channel dimension
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = torch.cat([latent_model_input, image], dim=1)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=noise_level,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        # 11. Decode latents and run the safety checker
        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

            if needs_upcasting:
                self.upcast_vae()

            # Ensure latents are always the same type as the VAE
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)

            image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        # Images flagged by the safety checker are not denormalized.
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # 12. Apply watermark
        if output_type == "pil" and self.watermarker is not None:
            image = self.watermarker.apply_watermark(image)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput

from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
from ...models.embeddings import get_timestep_embedding
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableUnCLIPPipeline

        >>> pipe = StableUnCLIPPipeline.from_pretrained(
        ...     "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16
        ... )  # TODO update model path
        >>> pipe = pipe.to("cuda")

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> images = pipe(prompt).images
        >>> images[0].save("astronaut_horse.png")
        ```
"""


class StableUnCLIPPipeline(
    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
):
    """
    Pipeline for text-to-image generation using stable unCLIP.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights

    Args:
        prior_tokenizer ([`CLIPTokenizer`]):
            A [`CLIPTokenizer`].
        prior_text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen [`CLIPTextModelWithProjection`] text-encoder.
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        prior_scheduler ([`KarrasDiffusionSchedulers`]):
            Scheduler used in the prior denoising process.
        image_normalizer ([`StableUnCLIPImageNormalizer`]):
            Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
            embeddings after the noise has been applied.
        image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
            Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
            by the `noise_level`.
        tokenizer ([`CLIPTokenizer`]):
            A [`CLIPTokenizer`].
        text_encoder ([`CLIPTextModel`]):
            Frozen [`CLIPTextModel`] text-encoder.
        unet ([`UNet2DConditionModel`]):
            A [`UNet2DConditionModel`] to denoise the encoded image latents.
        scheduler ([`KarrasDiffusionSchedulers`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
    """

    _exclude_from_cpu_offload = ["prior", "image_normalizer"]
    model_cpu_offload_seq = "text_encoder->prior_text_encoder->unet->vae"

    # prior components
    prior_tokenizer: CLIPTokenizer
    prior_text_encoder: CLIPTextModelWithProjection
    prior: PriorTransformer
    prior_scheduler: KarrasDiffusionSchedulers

    # image noising components
    image_normalizer: StableUnCLIPImageNormalizer
    image_noising_scheduler: KarrasDiffusionSchedulers

    # regular denoising components
    tokenizer: CLIPTokenizer
    text_encoder: CLIPTextModel
    unet: UNet2DConditionModel
    scheduler: KarrasDiffusionSchedulers

    vae: AutoencoderKL

    def __init__(
        self,
        # prior components
        prior_tokenizer: CLIPTokenizer,
        prior_text_encoder: CLIPTextModelWithProjection,
        prior: PriorTransformer,
        prior_scheduler: KarrasDiffusionSchedulers,
        # image noising components
        image_normalizer: StableUnCLIPImageNormalizer,
        image_noising_scheduler: KarrasDiffusionSchedulers,
        # regular denoising components
        tokenizer: CLIPTokenizer,
        # Annotated as `CLIPTextModel` to agree with the class-level attribute and the class docstring;
        # `encode_prompt` unpacks the first encoder output as a (batch, seq_len, dim) tensor.
        text_encoder: CLIPTextModel,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        # vae
        vae: AutoencoderKL,
    ):
        super().__init__()

        self.register_modules(
            prior_tokenizer=prior_tokenizer,
            prior_text_encoder=prior_text_encoder,
            prior=prior,
            prior_scheduler=prior_scheduler,
            image_normalizer=image_normalizer,
            image_noising_scheduler=image_noising_scheduler,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            vae=vae,
        )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder
    def _encode_prior_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
    ):
        if text_model_output is None:
            batch_size = len(prompt) if isinstance(prompt, list) else 1
            # get prompt text embeddings
            text_inputs = self.prior_tokenizer(
                prompt,
                padding="max_length",
                max_length=self.prior_tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            text_mask = text_inputs.attention_mask.bool().to(device)

            untruncated_ids = self.prior_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.prior_tokenizer.batch_decode(
                    untruncated_ids[:, self.prior_tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.prior_tokenizer.model_max_length} tokens: {removed_text}"
                )
                text_input_ids = text_input_ids[:, : self.prior_tokenizer.model_max_length]

            prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device))

            prompt_embeds = prior_text_encoder_output.text_embeds
            text_enc_hid_states = prior_text_encoder_output.last_hidden_state

        else:
            batch_size = text_model_output[0].shape[0]
            prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
            text_mask = text_attention_mask

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens = [""] * batch_size

            uncond_input = self.prior_tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.prior_tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
            negative_prompt_embeds_prior_text_encoder_output = self.prior_text_encoder(
                uncond_input.input_ids.to(device)
            )

            negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds
            uncond_text_enc_hid_states = negative_prompt_embeds_prior_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_enc_hid_states.shape[1]
            uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_enc_hid_states, text_mask

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        **kwargs,
    ):
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )

        # concatenate for backwards comp
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with prepare_extra_step_kwargs->prepare_prior_extra_step_kwargs, scheduler->prior_scheduler
    def prepare_prior_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the prior_scheduler step, since not all prior_schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other prior_schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.prior_scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the prior_scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.prior_scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        noise_level,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """
        Validate the arguments of `__call__` before any model is run.

        Raises:
            ValueError: If `height` or `width` is not divisible by 8, if `callback_steps` is not a positive `int`,
                if `prompt` and `prompt_embeds` are both given or both missing, if `prompt` is neither a `str` nor
                a `list`, if `negative_prompt` and `negative_prompt_embeds` are both given, if `prompt_embeds` and
                `negative_prompt_embeds` have different shapes, or if `noise_level` lies outside
                `[0, image_noising_scheduler.config.num_train_timesteps)`.
            TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # `None` is not an `int`, so this single test also rejects a missing value.
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two."
            )

        if prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )

        if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            # Both were supplied; the message must describe that condition (it previously claimed that
            # both had been left undefined).
            raise ValueError(
                "Provide either `negative_prompt` or `negative_prompt_embeds`. Please make sure to define only one"
                " of the two."
            )

        if prompt is not None and negative_prompt is not None:
            if type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
            raise ValueError(
                f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
            )

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def noise_image_embeddings(
        self,
        image_embeds: torch.Tensor,
        noise_level: int,
        noise: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ):
        """
        Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
        `noise_level` increases the variance in the final un-noised images.

        The noise is applied in two ways:
        1. A noise schedule is applied directly to the embeddings.
        2. A vector of sinusoidal time embeddings are appended to the output.

        In both cases, the amount of noise is controlled by the same `noise_level`.

        The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.

        Args:
            image_embeds (`torch.Tensor`):
                Image embeddings to noise. The first dimension is used as the batch size and the last dimension as
                the width of the appended timestep embedding (presumably `(batch, embedding_dim)`; confirm against
                the caller).
            noise_level (`int`):
                Timestep of `image_noising_scheduler` at which the noise is added; the same value is used for
                every element of the batch.
            noise (`torch.Tensor`, *optional*):
                Noise to add. When omitted, it is sampled with the same shape, device and dtype as `image_embeds`.
            generator (`torch.Generator`, *optional*):
                Generator used to sample `noise` when it is not given.

        Returns:
            `torch.Tensor`: The noised embeddings concatenated along dimension 1 with the timestep embedding of
            `noise_level`.
        """
        if noise is None:
            noise = randn_tensor(
                image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
            )

        # one (identical) timestep per batch element
        noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)

        self.image_normalizer.to(image_embeds.device)
        image_embeds = self.image_normalizer.scale(image_embeds)

        image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)

        image_embeds = self.image_normalizer.unscale(image_embeds)

        noise_level = get_timestep_embedding(
            timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
        )

        # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors,
        # but we might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        noise_level = noise_level.to(image_embeds.dtype)

        image_embeds = torch.cat((image_embeds, noise_level), 1)

        return image_embeds

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        # regular denoising process args
        prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 20,
        guidance_scale: float = 10.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        noise_level: int = 0,
        # prior args
        prior_num_inference_steps: int = 25,
        prior_guidance_scale: float = 4.0,
        prior_latents: Optional[torch.Tensor] = None,
        clip_skip: Optional[int] = None,
    ):
        """
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 20):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 10.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            noise_level (`int`, *optional*, defaults to `0`):
                The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
                the final un-noised images. See [`StableUnCLIPPipeline.noise_image_embeddings`] for more details.
            prior_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps in the prior denoising process. More denoising steps usually lead to a
                higher quality image at the expense of slower inference.
            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when
                `prior_guidance_scale > 1`.
            prior_latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                embedding generation in the prior denoising process. Can be used to tweak the same generation with
                different prompts. If not provided, a latents tensor is generated by sampling using the supplied random
                `generator`.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                [`~pipeline_utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning
                a tuple, the first element is a list with the generated images.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            height=height,
            width=width,
            callback_steps=callback_steps,
            noise_level=noise_level,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        batch_size = batch_size * num_images_per_prompt

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        prior_do_classifier_free_guidance = prior_guidance_scale > 1.0

        # 3. Encode input prompt
        # NOTE(review): the prior stage is given only the raw `prompt` (no `text_model_output`), so
        # `_encode_prior_prompt` tokenizes it directly. A call that supplies just `prompt_embeds` therefore
        # appears to be unsupported by this stage -- confirm before relying on it.
        prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask = self._encode_prior_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=prior_do_classifier_free_guidance,
        )

        # 4. Prepare prior timesteps
        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
        prior_timesteps_tensor = self.prior_scheduler.timesteps

        # 5. Prepare prior latent variables
        embedding_dim = self.prior.config.embedding_dim
        prior_latents = self.prepare_latents(
            (batch_size, embedding_dim),
            prior_prompt_embeds.dtype,
            device,
            generator,
            prior_latents,
            self.prior_scheduler,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        prior_extra_step_kwargs = self.prepare_prior_extra_step_kwargs(generator, eta)

        # 7. Prior denoising loop
        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([prior_latents] * 2) if prior_do_classifier_free_guidance else prior_latents
            latent_model_input = self.prior_scheduler.scale_model_input(latent_model_input, t)

            predicted_image_embedding = self.prior(
                latent_model_input,
                timestep=t,
                proj_embedding=prior_prompt_embeds,
                encoder_hidden_states=prior_text_encoder_hidden_states,
                attention_mask=prior_text_mask,
            ).predicted_image_embedding

            if prior_do_classifier_free_guidance:
                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
                predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
                    predicted_image_embedding_text - predicted_image_embedding_uncond
                )

            prior_latents = self.prior_scheduler.step(
                predicted_image_embedding,
                timestep=t,
                sample=prior_latents,
                **prior_extra_step_kwargs,
                return_dict=False,
            )[0]

            if callback is not None and i % callback_steps == 0:
                callback(i, t, prior_latents)

        prior_latents = self.prior.post_process_latents(prior_latents)

        image_embeds = prior_latents

        # done prior

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 8. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 9. Prepare image embeddings
        image_embeds = self.noise_image_embeddings(
            image_embeds=image_embeds,
            noise_level=noise_level,
            generator=generator,
        )

        if do_classifier_free_guidance:
            # The unconditional branch is conditioned on all-zero image embeddings. A dedicated name is used
            # so the text `negative_prompt_embeds` above is not shadowed.
            negative_image_embeds = torch.zeros_like(image_embeds)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and image embeddings into a single batch
            # to avoid doing two forward passes
            image_embeds = torch.cat([negative_image_embeds, image_embeds])

        # 10. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 11. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        latents = self.prepare_latents(
            shape=shape,
            dtype=prompt_embeds.dtype,
            device=device,
            generator=generator,
            latents=latents,
            scheduler=self.scheduler,
        )

        # 12. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 13. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                class_labels=image_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union import PIL.Image import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import VaeImageProcessor from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import requests >>> import torch >>> from PIL import Image >>> from io import BytesIO >>> from diffusers import StableUnCLIPImg2ImgPipeline >>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( ... "stabilityai/stable-diffusion-2-1-unclip-small", torch_dtype=torch.float16 ... 
class StableUnCLIPImg2ImgPipeline(
    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
):
    """
    Pipeline for text-guided image-to-image generation using stable unCLIP.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights

    Args:
        feature_extractor ([`CLIPImageProcessor`]):
            Feature extractor for image pre-processing before being encoded.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            CLIP vision model for encoding images.
        image_normalizer ([`StableUnCLIPImageNormalizer`]):
            Used to normalize the image embeddings before the noise is applied and un-normalize the image embeddings
            after the noise has been applied.
        image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
            Noise schedule for adding noise to the image embeddings. The amount of noise to add is determined by the
            `noise_level`.
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A [`~transformers.CLIPTokenizer`].
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen [`~transformers.CLIPTextModel`] text-encoder.
        unet ([`UNet2DConditionModel`]):
            A [`UNet2DConditionModel`] to denoise the encoded image latents.
        scheduler ([`KarrasDiffusionSchedulers`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
    """

    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    _exclude_from_cpu_offload = ["image_normalizer"]

    # image encoding components
    feature_extractor: CLIPImageProcessor
    image_encoder: CLIPVisionModelWithProjection

    # image noising components
    image_normalizer: StableUnCLIPImageNormalizer
    image_noising_scheduler: KarrasDiffusionSchedulers

    # regular denoising components
    tokenizer: CLIPTokenizer
    text_encoder: CLIPTextModel
    unet: UNet2DConditionModel
    scheduler: KarrasDiffusionSchedulers

    vae: AutoencoderKL

    def __init__(
        self,
        # image encoding components
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection,
        # image noising components
        image_normalizer: StableUnCLIPImageNormalizer,
        image_noising_scheduler: KarrasDiffusionSchedulers,
        # regular denoising components
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        # vae
        vae: AutoencoderKL,
    ):
        """Register all sub-models and derive the VAE scale factor and the image post-processor from `vae`."""
        super().__init__()

        self.register_modules(
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
            image_normalizer=image_normalizer,
            image_noising_scheduler=image_noising_scheduler,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            vae=vae,
        )

        # Ratio between pixel-space and latent-space resolution, derived from the number of VAE blocks.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        **kwargs,
    ):
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )

        # concatenate for backwards comp
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        return prompt_embeds

    def _encode_image(
        self,
        image,
        device,
        batch_size,
        num_images_per_prompt,
        do_classifier_free_guidance,
        noise_level,
        generator,
        image_embeds,
    ):
        """
        Build the noised image conditioning that is passed to the UNet as `class_labels`.

        Args:
            image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`, *optional*):
                Image(s) to embed with `self.image_encoder`. Only used when `image_embeds` is `None`.
            device (`torch.device`):
                Device the image is moved to before encoding.
            batch_size (`int`):
                Total batch size of the prompt (already multiplied by `num_images_per_prompt` by the caller).
            num_images_per_prompt (`int`):
                Number of images generated per prompt.
            do_classifier_free_guidance (`bool`):
                Whether an all-zero unconditional embedding is prepended to the batch.
            noise_level:
                Noise level forwarded to `noise_image_embeddings`.
            generator (`torch.Generator`, *optional*):
                Generator used to sample the embedding noise.
            image_embeds (`torch.Tensor`, *optional*):
                Pre-computed image embeddings. When given, `image` is not encoded.

        Returns:
            `torch.Tensor`: The noised image embeddings with the noise-level embedding appended, repeated to match
            the batch, and with a zero tensor of the same shape concatenated in front when
            `do_classifier_free_guidance` is `True`.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        if isinstance(image, PIL.Image.Image):
            # the image embedding should repeated so it matches the total batch size of the prompt
            repeat_by = batch_size
        else:
            # assume the image input is already properly batched and just needs to be repeated so
            # it matches the num_images_per_prompt.
            #
            # NOTE(will) this is probably missing a few number of side cases. I.e. batched/non-batched
            # `image_embeds`. If those happen to be common use cases, let's think harder about
            # what the expected dimensions of inputs should be and how we handle the encoding.
            repeat_by = num_images_per_prompt

        if image_embeds is None:
            if not isinstance(image, torch.Tensor):
                image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

            image = image.to(device=device, dtype=dtype)
            image_embeds = self.image_encoder(image).image_embeds

        image_embeds = self.noise_image_embeddings(
            image_embeds=image_embeds,
            noise_level=noise_level,
            generator=generator,
        )

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        image_embeds = image_embeds.unsqueeze(1)
        bs_embed, seq_len, _ = image_embeds.shape
        image_embeds = image_embeds.repeat(1, repeat_by, 1)
        image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1)
        image_embeds = image_embeds.squeeze(1)

        if do_classifier_free_guidance:
            negative_prompt_embeds = torch.zeros_like(image_embeds)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeds = torch.cat([negative_prompt_embeds, image_embeds])

        return image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        image,
        height,
        width,
        callback_steps,
        noise_level,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        image_embeds=None,
    ):
        """
        Validate the arguments of `__call__` before any model is run.

        Raises:
            ValueError:
                If `height`/`width` are not divisible by 8, `callback_steps` is not a positive integer, the
                `prompt`/`prompt_embeds`, `negative_prompt`/`negative_prompt_embeds` or `image`/`image_embeds` pairs
                are inconsistently defined, the embedding shapes differ, `noise_level` is out of range, or `prompt`
                or `image` has an unsupported type.
            TypeError:
                If `prompt` and `negative_prompt` are both given but have different types.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # `None` is not an `int`, so this single test also rejects a missing `callback_steps`.
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two."
            )

        if prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )

        if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # Both being set is the error here (leaving both unset is allowed), so the message says so.
        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                "Provide either `negative_prompt` or `negative_prompt_embeds`. Please make sure to define only one of"
                " the two."
            )

        if prompt is not None and negative_prompt is not None:
            if type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
            raise ValueError(
                f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
            )

        if image is not None and image_embeds is not None:
            raise ValueError(
                "Provide either `image` or `image_embeds`. Please make sure to define only one of the two."
            )

        if image is None and image_embeds is None:
            raise ValueError(
                "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
            )

        if image is not None:
            if (
                not isinstance(image, torch.Tensor)
                and not isinstance(image, PIL.Image.Image)
                and not isinstance(image, list)
            ):
                raise ValueError(
                    "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
                    f" {type(image)}"
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings
    def noise_image_embeddings(
        self,
        image_embeds: torch.Tensor,
        noise_level: int,
        noise: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ):
        """
        Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
        `noise_level` increases the variance in the final un-noised images.

        The noise is applied in two ways:
        1. A noise schedule is applied directly to the embeddings.
        2. A vector of sinusoidal time embeddings are appended to the output.

        In both cases, the amount of noise is controlled by the same `noise_level`.

        The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
        """
        if noise is None:
            noise = randn_tensor(
                image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
            )

        noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)

        self.image_normalizer.to(image_embeds.device)
        image_embeds = self.image_normalizer.scale(image_embeds)

        image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)

        image_embeds = self.image_normalizer.unscale(image_embeds)

        noise_level = get_timestep_embedding(
            timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
        )

        # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors,
        # but we might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        noise_level = noise_level.to(image_embeds.dtype)

        image_embeds = torch.cat((image_embeds, noise_level), 1)

        return image_embeds

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: Union[torch.Tensor, PIL.Image.Image] = None,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 20,
        guidance_scale: float = 10,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        noise_level: int = 0,
        image_embeds: Optional[torch.Tensor] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`torch.Tensor` or `PIL.Image.Image`):
                `Image` or tensor representing an image batch. The image is encoded to its CLIP embedding which the
                `unet` is conditioned on. The image is _not_ encoded by the `vae` and then used as the latents in the
                denoising process like it is in the standard Stable Diffusion text-guided image variation process.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be
                used or prompt is initialized to `""`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 20):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 10.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`. With `"latent"`
                the denoised latents are returned without being decoded by the `vae`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            noise_level (`int`, *optional*, defaults to `0`):
                The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
                the final un-noised images. See [`StableUnCLIPPipeline.noise_image_embeddings`] for more details.
            image_embeds (`torch.Tensor`, *optional*):
                Pre-generated CLIP embeddings to condition the `unet` on. These latents are not used in the denoising
                process. If you want to provide pre-generated latents, pass them to `__call__` as `latents`.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning
                a tuple, the first element is a list with the generated images.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # Without any text input, fall back to one empty prompt per input image.
        if prompt is None and prompt_embeds is None:
            prompt = len(image) * [""] if isinstance(image, list) else ""

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            image=image,
            height=height,
            width=width,
            callback_steps=callback_steps,
            noise_level=noise_level,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            image_embeds=image_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        batch_size = batch_size * num_images_per_prompt

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            # Forward `clip_skip`; otherwise the documented argument would be silently ignored.
            clip_skip=clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Encode input image
        noise_level = torch.tensor([noise_level], device=device)

        image_embeds = self._encode_image(
            image=image,
            device=device,
            batch_size=batch_size,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            noise_level=noise_level,
            generator=generator,
            image_embeds=image_embeds,
        )

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        # NOTE(review): caller-supplied `latents` skip `prepare_latents`, so they are used as given, i.e. neither
        # moved to `device` nor multiplied by `scheduler.init_noise_sigma`. Confirm this is intended.
        if latents is None:
            latents = self.prepare_latents(
                batch_size=batch_size,
                num_channels_latents=num_channels_latents,
                height=height,
                width=width,
                dtype=prompt_embeds.dtype,
                device=device,
                generator=generator,
                latents=latents,
            )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                class_labels=image_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

        # 9. Post-processing
        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
def cosine_distance(image_embeds, text_embeds):
    """Return the cosine similarity between every row of `image_embeds` and every row of `text_embeds`.

    Both inputs are L2-normalised with `nn.functional.normalize` (default settings) and combined with
    `torch.mm`, so the result has one row per image embedding and one column per text embedding.
    """
    normalized_image_embeds = nn.functional.normalize(image_embeds)
    normalized_text_embeds = nn.functional.normalize(text_embeds)
    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())


class StableDiffusionSafetyChecker(PreTrainedModel):
    """Flag images whose projected CLIP embedding is too similar to a stored concept embedding.

    The model holds 17 "concept" embeddings and 3 "special care" embeddings, each with a per-concept
    threshold stored in the matching `*_weights` parameter. Flagged images are replaced by all-zero
    (black) images.
    """

    config_class = CLIPConfig
    main_input_name = "clip_input"

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)

        # Placeholder values; none of these parameters receive gradients.
        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)

        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)

    @torch.no_grad()
    def forward(self, clip_input, images):
        """Check a batch of images and black out the ones that match a concept.

        Args:
            clip_input: Input forwarded to `self.vision_model`; one entry per image.
            images: Indexable batch of images (a tensor, a sequence of tensors, or array-like objects
                exposing `.shape`). Flagged entries are overwritten in place.

        Returns:
            `(images, has_nsfw_concepts)` where `has_nsfw_concepts` is a list with one bool per image.
        """
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()

        # Read the thresholds once. `tolist()` yields the same Python floats as calling `.item()` on
        # every element, without one host synchronisation per concept per image.
        special_thresholds = self.special_care_embeds_weights.tolist()
        concept_thresholds = self.concept_embeds_weights.tolist()

        has_nsfw_concepts = []
        for special_row, concept_row in zip(special_cos_dist, cos_dist):
            # A hit on any special-care concept lowers the bar for the regular concepts by 0.01.
            # Increase the base value (0.0) to create a stronger `nsfw` filter at the cost of
            # increasing the possibility of filtering benign images.
            special_care = any(
                round(concept_cos - special_thresholds[concept_idx], 3) > 0
                for concept_idx, concept_cos in enumerate(special_row)
            )
            adjustment = 0.01 if special_care else 0.0

            has_nsfw_concepts.append(
                any(
                    round(concept_cos - concept_thresholds[concept_idx] + adjustment, 3) > 0
                    for concept_idx, concept_cos in enumerate(concept_row)
                )
            )

        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                if torch.is_tensor(images) or torch.is_tensor(images[0]):
                    images[idx] = torch.zeros_like(images[idx])  # black image
                else:
                    images[idx] = np.zeros(images[idx].shape)  # black image

        if any(has_nsfw_concepts):
            logger.warning(
                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts

    @torch.no_grad()
    def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor):
        """Tensor-only variant of `forward` without Python-side loops or rounding.

        Args:
            clip_input: Input forwarded to `self.vision_model`.
            images: Image tensor indexed by batch; flagged entries are set to 0.0 in place.

        Returns:
            `(images, has_nsfw_concepts)` where `has_nsfw_concepts` is a boolean tensor with one entry
            per image.
        """
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
        cos_dist = cosine_distance(image_embeds, self.concept_embeds)

        # increase this value to create a stronger `nsfw` filter
        # at the cost of increasing the possibility of filtering benign images
        adjustment = 0.0

        special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
        # special_scores = special_scores.round(decimals=3)
        special_care = torch.any(special_scores > 0, dim=1)
        special_adjustment = special_care * 0.01
        # Broadcast the per-image adjustment across all regular concepts.
        special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])

        concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
        # concept_scores = concept_scores.round(decimals=3)
        has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)

        images[has_nsfw_concepts] = 0.0  # black image

        return images, has_nsfw_concepts
def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
    """Return the cosine similarity between each row of `emb_1` and each row of `emb_2`.

    Row norms are clipped from below at `eps` before dividing.
    """

    def _unit_rows(emb):
        row_norms = jnp.clip(jnp.linalg.norm(emb, axis=1), a_min=eps)
        return jnp.divide(emb.T, row_norms).T

    return jnp.matmul(_unit_rows(emb_1), _unit_rows(emb_2).T)


class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
    """Flax module that scores image embeddings against stored concept embeddings."""

    config: CLIPConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        projection_dim = self.config.projection_dim
        ones_init = jax.nn.initializers.ones

        self.vision_model = FlaxCLIPVisionModule(self.config.vision_config)
        self.visual_projection = nn.Dense(projection_dim, use_bias=False, dtype=self.dtype)

        self.concept_embeds = self.param("concept_embeds", ones_init, (17, projection_dim))
        self.special_care_embeds = self.param("special_care_embeds", ones_init, (3, projection_dim))

        self.concept_embeds_weights = self.param("concept_embeds_weights", ones_init, (17,))
        self.special_care_embeds_weights = self.param("special_care_embeds_weights", ones_init, (3,))

    def __call__(self, clip_input):
        """Return one boolean per input image telling whether any concept score is positive."""
        image_embeds = self.visual_projection(self.vision_model(clip_input)[1])

        special_similarity = jax_cosine_distance(image_embeds, self.special_care_embeds)
        concept_similarity = jax_cosine_distance(image_embeds, self.concept_embeds)

        # increase this value to create a stronger `nsfw` filter
        # at the cost of increasing the possibility of filtering benign image inputs
        adjustment = 0.0

        special_scores = jnp.round(
            special_similarity - self.special_care_embeds_weights[None, :] + adjustment, 3
        )
        flagged_special = jnp.any(special_scores > 0, axis=1, keepdims=True)

        # Use a lower threshold if an image has any special care concept
        concept_scores = jnp.round(
            concept_similarity - self.concept_embeds_weights[None, :] + flagged_special * 0.01, 3
        )
        return jnp.any(concept_scores > 0, axis=1)


class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
    """`FlaxPreTrainedModel` wrapper around `FlaxStableDiffusionSafetyCheckerModule`."""

    config_class = CLIPConfig
    main_input_name = "clip_input"
    module_class = FlaxStableDiffusionSafetyCheckerModule

    def __init__(
        self,
        config: CLIPConfig,
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # Fall back to a single 224x224 channels-last RGB image when no shape is given.
        shape = (1, 224, 224, 3) if input_shape is None else input_shape
        wrapped = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, wrapped, input_shape=shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.Array, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """Initialise the module on a random input of `input_shape` and return its `"params"` tree."""
        # init input tensor
        dummy_input = jax.random.normal(rng, input_shape)

        params_rng, dropout_rng = jax.random.split(rng)
        variables = self.module.init({"params": params_rng, "dropout": dropout_rng}, dummy_input)
        return variables["params"]

    def __call__(
        self,
        clip_input,
        params: dict = None,
    ):
        """Run the checker on `clip_input` after moving axis 1 to the last position."""
        channels_last = jnp.transpose(clip_input, (0, 2, 3, 1))
        model_input = jnp.array(channels_last, dtype=jnp.float32)

        return self.module.apply(
            {"params": params or self.params},
            model_input,
            rngs={},
        )
class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin):
    """
    Holds the mean and standard deviation of the CLIP embedder used in stable unCLIP.

    `scale` normalizes image embeddings before noise is applied; `unscale` reverses that on the noised
    image embeddings.
    """

    @register_to_config
    def __init__(
        self,
        embedding_dim: int = 768,
    ):
        super().__init__()

        stats_shape = (1, embedding_dim)
        self.mean = nn.Parameter(torch.zeros(*stats_shape))
        self.std = nn.Parameter(torch.ones(*stats_shape))

    def to(
        self,
        torch_device: Optional[Union[str, torch.device]] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ):
        """Move/cast both statistics, re-wrap them as parameters, and return `self`."""
        for stat_name in ("mean", "std"):
            converted = getattr(self, stat_name).to(torch_device).to(torch_dtype)
            setattr(self, stat_name, nn.Parameter(converted))
        return self

    def scale(self, embeds):
        """Return `(embeds - mean) / std`."""
        centered = embeds - self.mean
        return centered * 1.0 / self.std

    def unscale(self, embeds):
        """Return `embeds * std + mean`, the inverse of `scale`."""
        return (embeds * self.std) + self.mean
["StableDiffusion3PipelineOutput"]} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_stable_diffusion_3"] = ["StableDiffusion3Pipeline"] _import_structure["pipeline_stable_diffusion_3_img2img"] = ["StableDiffusion3Img2ImgPipeline"] _import_structure["pipeline_stable_diffusion_3_inpaint"] = ["StableDiffusion3InpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_stable_diffusion_3 import StableDiffusion3Pipeline from .pipeline_stable_diffusion_3_img2img import StableDiffusion3Img2ImgPipeline from .pipeline_stable_diffusion_3_inpaint import StableDiffusion3InpaintPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) for name, value in _additional_imports.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/stable_diffusion_3/pipeline_output.py ================================================ from dataclasses import dataclass from typing import List, Union import numpy as np import PIL.Image from ...utils import BaseOutput @dataclass class StableDiffusion3PipelineOutput(BaseOutput): """ Output class for Stable Diffusion pipelines. 
Args: images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. """ images: Union[List[PIL.Image.Image], np.ndarray] ================================================ FILE: src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py ================================================ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of `num_inference_steps`, `timesteps` or `sigmas` drives the schedule. Extra keyword
    arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            Number of steps passed to `set_timesteps` when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Requires `set_timesteps` to accept a `timesteps` argument; cannot be
            combined with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Requires `set_timesteps` to accept a `sigmas` argument; cannot be combined
            with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after the call, and the number of inference
        steps (the length of the schedule when a custom schedule was supplied, otherwise
        `num_inference_steps` unchanged).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps`
            does not accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive the schedule from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None:
        option, values, label = "timesteps", timesteps, "timestep schedules"
    else:
        option, values, label = "sigmas", sigmas, "sigmas schedules"

    supported = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if option not in supported:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {label}. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(**{option: values}, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). """ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] def __init__( self, transformer: SD3Transformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, text_encoder_2: CLIPTextModelWithProjection, tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, text_encoder_3=text_encoder_3, tokenizer=tokenizer, tokenizer_2=tokenizer_2, tokenizer_3=tokenizer_3, transformer=transformer, scheduler=scheduler, ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) self.default_sample_size = ( self.transformer.config.sample_size if hasattr(self, "transformer") and self.transformer is not None else 128 ) def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, max_sequence_length: int = 256, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = 
len(prompt) if self.text_encoder_3 is None: return torch.zeros( ( batch_size * num_images_per_prompt, self.tokenizer_max_length, self.transformer.config.joint_attention_dim, ), device=device, dtype=dtype, ) text_inputs = self.tokenizer_3( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] dtype = self.text_encoder_3.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds def _get_clip_prompt_embeds( self, prompt: Union[str, List[str]], num_images_per_prompt: int = 1, device: Optional[torch.device] = None, clip_skip: Optional[int] = None, clip_model_index: int = 0, ): device = device or self._execution_device clip_tokenizers = [self.tokenizer, self.tokenizer_2] clip_text_encoders = [self.text_encoder, self.text_encoder_2] tokenizer = clip_tokenizers[clip_model_index] text_encoder = clip_text_encoders[clip_model_index] prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = tokenizer( prompt, padding="max_length", max_length=self.tokenizer_max_length, truncation=True, 
return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds def encode_prompt( self, prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], prompt_3: Union[str, List[str]], device: Optional[torch.device] = None, num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt_3: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 
clip_skip: Optional[int] = None, max_sequence_length: int = 256, lora_scale: Optional[float] = None, ): r""" Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in all text-encoders prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is used in all text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 prompt_3 = prompt_3 or prompt prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=clip_skip, clip_model_index=0, ) prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( 
prompt=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=clip_skip, clip_model_index=1, ) clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) t5_prompt_embed = self._get_t5_prompt_embeds( prompt=prompt_3, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) clip_prompt_embeds = torch.nn.functional.pad( clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) ) prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt negative_prompt_3 = negative_prompt_3 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) negative_prompt_3 = ( batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 ) if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
def check_inputs(
    self,
    prompt,
    prompt_2,
    prompt_3,
    height,
    width,
    negative_prompt=None,
    negative_prompt_2=None,
    negative_prompt_3=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    max_sequence_length=None,
):
    """Validate the user-facing arguments of a generation call.

    The checks run in a fixed order and the first violated one raises.

    Args:
        prompt, prompt_2, prompt_3: Text prompts (`str` or `list`). None of them may be combined with
            `prompt_embeds`, and either `prompt` or `prompt_embeds` must be given.
        height, width: Requested image size; both must be divisible by 8.
        negative_prompt, negative_prompt_2, negative_prompt_3: Negative text prompts. None of them may be
            combined with `negative_prompt_embeds`.
        prompt_embeds, negative_prompt_embeds: Pre-computed embeddings. When both are given their `.shape`
            must match.
        pooled_prompt_embeds: Required whenever `prompt_embeds` is given.
        negative_pooled_prompt_embeds: Required whenever `negative_prompt_embeds` is given.
        callback_on_step_end_tensor_inputs: Names requested for the step-end callback; every name must be
            listed in `self._callback_tensor_inputs`.
        max_sequence_length: Must not exceed 512 when given.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Iterating over (name, value) pairs guarantees each message reports the argument that
    # actually triggered it (the hand-written `prompt_3` branch used to print `prompt_2`).
    positive_prompts = (("prompt", prompt), ("prompt_2", prompt_2), ("prompt_3", prompt_3))

    for name, value in positive_prompts:
        if value is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `{name}`: {value} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )

    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )

    for name, value in positive_prompts:
        if value is not None and not isinstance(value, (str, list)):
            raise ValueError(f"`{name}` has to be of type `str` or `list` but is {type(value)}")

    negative_prompts = (
        ("negative_prompt", negative_prompt),
        ("negative_prompt_2", negative_prompt_2),
        ("negative_prompt_3", negative_prompt_3),
    )
    for name, value in negative_prompts:
        if value is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `{name}`: {value} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )

    if max_sequence_length is not None and max_sequence_length > 512:
        raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 256,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
            be used instead.
        prompt_3 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` will
            be used instead.
        height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image. This is set to 1024 by default for the best results.
        width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image. This is set to 1024 by default for the best results.
        num_inference_steps (`int`, *optional*, defaults to 28):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 7.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Not encoded when guidance is disabled (i.e., when `guidance_scale`
            is not greater than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used instead.
        negative_prompt_3 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
            `text_encoder_3`. If not defined, `negative_prompt` is used instead.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
            latents are returned as-is, without VAE decoding.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
            a plain tuple.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its `"scale"` entry, when present, is also forwarded to `encode_prompt` as `lora_scale`.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`. The returned object is read with `.pop(key, default)` to pick
            up replaced tensors.
        callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int`, defaults to 256): Maximum sequence length to use with the `prompt`.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images.
    """

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    # stash per-call settings on the instance; the read-only properties expose them
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode the prompts (and negative prompts when guidance is enabled)
    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # negative embeddings go first so that `chunk(2)` below yields (uncond, text)
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Denoising loop
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # while `self.interrupt` is true the remaining iterations are skipped without running the model
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            # call the step-end callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                # the requested tensors are looked up by name among this function's local variables
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # the callback may replace any of these tensors; missing keys keep the current value
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                negative_pooled_prompt_embeds = callback_outputs.pop(
                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                )

            # advance the progress bar on the last step, or after warmup once per `scheduler.order` steps
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    if output_type == "latent":
        image = latents

    else:
        # rescale the latents with the VAE config's scaling/shift factors before decoding
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return StableDiffusion3PipelineOutput(images=image)
# See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Callable, Dict, List, Optional, Union import PIL.Image import torch from transformers import ( CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, ) from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import StableDiffusion3PipelineOutput if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import AutoPipelineForImage2Image >>> from diffusers.utils import load_image >>> device = "cuda" >>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers" >>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16) >>> pipe = pipe.to(device) >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" >>> init_image = load_image(url).resize((1024, 1024)) >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" >>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0] ``` """ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Extra `kwargs` are forwarded to `scheduler.set_timesteps` unchanged.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Passed to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration and the number of inference
        steps (the length of the schedule when a custom schedule was supplied, otherwise
        `num_inference_steps` as passed in).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    has_timesteps = timesteps is not None
    has_sigmas = sigmas is not None

    if has_timesteps and has_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler derive its own schedule from the step count.
    if not has_timesteps and not has_sigmas:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom mode: the scheduler must explicitly accept the override as a keyword.
    supported_params = inspect.signature(scheduler.set_timesteps).parameters

    if has_timesteps:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                " timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                " sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
text_encoder ([`CLIPTextModelWithProjection`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` as its dimension. text_encoder_2 ([`CLIPTextModelWithProjection`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. text_encoder_3 ([`T5EncoderModel`]): Frozen text-encoder. Stable Diffusion 3 uses [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] def __init__( self, transformer: SD3Transformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, text_encoder_2: CLIPTextModelWithProjection, tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, text_encoder_3=text_encoder_3, tokenizer=tokenizer, tokenizer_2=tokenizer_2, tokenizer_3=tokenizer_3, transformer=transformer, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels ) self.tokenizer_max_length = self.tokenizer.model_max_length self.default_sample_size = self.transformer.config.sample_size # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, max_sequence_length: int = 256, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if self.text_encoder_3 is None: return torch.zeros( ( batch_size * num_images_per_prompt, self.tokenizer_max_length, self.transformer.config.joint_attention_dim, ), device=device, dtype=dtype, ) text_inputs = self.tokenizer_3( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, add_special_tokens=True, 
def _get_clip_prompt_embeds(
    self,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    clip_skip: Optional[int] = None,
    clip_model_index: int = 0,
):
    """Encode `prompt` with one of the two CLIP text encoders.

    Args:
        prompt: A single prompt or a list of prompts.
        num_images_per_prompt: How many times each prompt's embeddings are duplicated.
        device: Device the token ids and embeddings are moved to. Falls back to
            `self._execution_device` when not given.
        clip_skip: When `None`, the penultimate hidden state is used; otherwise the
            hidden state at index `-(clip_skip + 2)`.
        clip_model_index: `0` selects `tokenizer`/`text_encoder`, `1` selects
            `tokenizer_2`/`text_encoder_2`.

    Returns:
        A `(prompt_embeds, pooled_prompt_embeds)` tuple. Both have
        `batch_size * num_images_per_prompt` rows, and the copies of one prompt are
        adjacent in both tensors (`[p0, p0, ..., p1, p1, ...]`).
    """
    device = device or self._execution_device

    clip_tokenizers = [self.tokenizer, self.tokenizer_2]
    clip_text_encoders = [self.text_encoder, self.text_encoder_2]

    tokenizer = clip_tokenizers[clip_model_index]
    text_encoder = clip_text_encoders[clip_model_index]

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=self.tokenizer_max_length,
        truncation=True,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids
    # Tokenize again without truncation only to detect (and report) dropped text.
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {self.tokenizer_max_length} tokens: {removed_text}"
        )
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
    # First output of the encoder is the pooled embedding; expected to be 2-D (batch, dim).
    pooled_prompt_embeds = prompt_embeds[0]

    if clip_skip is None:
        prompt_embeds = prompt_embeds.hidden_states[-2]
    else:
        prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    # Repeat along the feature axis and then fold back into rows, so the pooled rows
    # follow the same [p0, p0, ..., p1, p1, ...] order as `prompt_embeds` above.
    # A three-argument `repeat(1, n, 1)` on this 2-D tensor would instead tile the
    # whole batch ([p0, p1, p0, p1, ...]) and pair prompts with the wrong pooled vector.
    # NOTE(review): this differs from the method named in the `# Copied from` marker;
    # copy-sync tooling may flag it.
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
    pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds, pooled_prompt_embeds
): r""" Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in all text-encoders prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is used in all text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 prompt_3 = prompt_3 or prompt prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=clip_skip, clip_model_index=0, ) prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( prompt=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=clip_skip, clip_model_index=1, ) clip_prompt_embeds = torch.cat([prompt_embed, 
prompt_2_embed], dim=-1) t5_prompt_embed = self._get_t5_prompt_embeds( prompt=prompt_3, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) clip_prompt_embeds = torch.nn.functional.pad( clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) ) prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt negative_prompt_3 = negative_prompt_3 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) negative_prompt_3 = ( batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 ) if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
def check_inputs(
    self,
    prompt,
    prompt_2,
    prompt_3,
    strength,
    negative_prompt=None,
    negative_prompt_2=None,
    negative_prompt_3=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    max_sequence_length=None,
):
    """Validate the arguments of a pipeline call before any model is run.

    Checks performed, in order:
      * `strength` lies in `[0, 1]`;
      * every requested callback tensor name is listed in `self._callback_tensor_inputs`;
      * text prompts and `prompt_embeds` are mutually exclusive, at least one of
        `prompt` / `prompt_embeds` is given, and prompts are `str` or `list`;
      * negative prompts and `negative_prompt_embeds` are mutually exclusive;
      * `prompt_embeds` and `negative_prompt_embeds` have equal shapes when both are given;
      * each embeds tensor is accompanied by its pooled counterpart;
      * `max_sequence_length` does not exceed 512.

    Returns:
        `None` when every check passes.

    Raises:
        ValueError: On the first check that fails.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_2 is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_3 is not None and prompt_embeds is not None:
        # Report the offending `prompt_3` value (not `prompt_2`) in the message.
        raise ValueError(
            f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif prompt_2 is not None and not isinstance(prompt_2, (str, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
    elif prompt_3 is not None and not isinstance(prompt_3, (str, list)):
        raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )

    if max_sequence_length is not None and max_sequence_length > 512:
        raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
) elif isinstance(generator, list): init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = retrieve_latents(self.vae.encode(image), generator=generator) init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) else: init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.scale_noise(init_latents, timestep, noise) latents = init_latents.to(device=device, dtype=dtype) return latents @property def guidance_scale(self): return self._guidance_scale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    image: PipelineImageInput = None,
    strength: float = 0.6,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 256,
):
    r"""
    Function invoked when calling the pipeline for image-to-image generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
            will be used instead.
        prompt_3 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
            will be used instead.
        image (`PipelineImageInput`):
            The input image(s). They are preprocessed with `self.image_processor` and, unless `latents` is
            passed, turned into the initial noisy latents by `prepare_latents`.
        strength (`float`, *optional*, defaults to 0.6):
            Must be between 0 and 1. Controls how much of the timestep schedule is run: `get_timesteps`
            skips the first `num_inference_steps - num_inference_steps * strength` (truncated to an integer)
            timesteps, so `1.0` runs the full schedule.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps before `strength` is applied. More denoising steps usually lead to
            a higher quality image at the expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps`
            argument in their `set_timesteps` method. If not defined, the default behavior when
            `num_inference_steps` is passed will be used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 7.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`.
            Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
            `guidance_scale` is not greater than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used instead.
        negative_prompt_3 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
            `text_encoder_3`. If not defined, `negative_prompt` is used instead.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents used as the starting point of the denoising loop. If not provided,
            the latents are produced from `image` by `prepare_latents` using the supplied `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.*
            prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from
            `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"`
            the final latents are returned without being decoded by the VAE.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`]
            instead of a plain tuple.
        clip_skip (`int`, *optional*):
            Forwarded to `encode_prompt`: number of layers to be skipped from CLIP while computing the prompt
            embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is
            called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int,
            timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as
            specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the
            list will be passed as `callback_kwargs` argument. You will only be able to include variables
            listed in the `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images.
    """

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        strength,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        max_sequence_length=max_sequence_length,
    )

    # Per-call state read back through the `guidance_scale`, `clip_skip` and `interrupt` properties.
    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
    )

    # With guidance enabled, negative and positive embeddings are stacked along the batch axis
    # (negative first) so a single transformer call yields both predictions.
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

    # 3. Preprocess image
    image = self.image_processor.preprocess(image)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    # Trim the schedule according to `strength`; both values are overwritten with the trimmed ones.
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    # The first remaining timestep, repeated once per generated sample, is the noise level for the init latents.
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

    # 5. Prepare latent variables
    if latents is None:
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            generator,
        )

    # 6. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Once `_interrupt` is set, the remaining iterations are skipped without denoising.
            # NOTE(review): presumably set from a step-end callback; nothing in this method sets it to True.
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                # Chunk order matches the concatenation order above: negative first, then positive.
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            # call the callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                # Tensors are looked up by local variable name, so the names listed in
                # `_callback_tensor_inputs` must match locals of this function.
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                # NOTE(review): the negative_* names are only rebound here; the loop feeds
                # `prompt_embeds` / `pooled_prompt_embeds` to the transformer.
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                negative_pooled_prompt_embeds = callback_outputs.pop(
                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                )

            # advance the progress bar on the last step, or after warmup once per `scheduler.order` iterations
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    if output_type == "latent":
        image = latents

    else:
        # Undo the scaling and shift applied when the image was encoded, then decode with the VAE.
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return StableDiffusion3PipelineOutput(images=image)
# See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Callable, Dict, List, Optional, Union import torch from transformers import ( CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, ) from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import StableDiffusion3PipelineOutput if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusion3InpaintPipeline >>> from diffusers.utils import load_image >>> pipe = StableDiffusion3InpaintPipeline.from_pretrained( ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16 ... 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read the timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps; used only when neither `timesteps` nor `sigmas` is passed.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's timestep spacing. Cannot be combined with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's timestep spacing. Cannot be combined with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's `timesteps` after the call, and the number of inference
        steps (the length of the schedule when a custom schedule was supplied, otherwise
        `num_inference_steps` as passed in).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` has no
        parameter for the custom schedule that was supplied.
    """
    custom_timesteps = timesteps is not None
    custom_sigmas = sigmas is not None

    if custom_timesteps and custom_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler derive its own schedule from the step count.
    if not custom_timesteps and not custom_sigmas:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # A custom schedule is only usable if `set_timesteps` exposes the matching parameter.
    accepted_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if custom_timesteps:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). """ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] def __init__( self, transformer: SD3Transformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, text_encoder_2: CLIPTextModelWithProjection, tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, text_encoder_3=text_encoder_3, tokenizer=tokenizer, tokenizer_2=tokenizer_2, tokenizer_3=tokenizer_3, transformer=transformer, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels ) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels, do_normalize=False, do_binarize=True, do_convert_grayscale=True, ) self.tokenizer_max_length = self.tokenizer.model_max_length self.default_sample_size = self.transformer.config.sample_size # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, max_sequence_length: int = 256, device: Optional[torch.device] = None, dtype: 
Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if self.text_encoder_3 is None: return torch.zeros( ( batch_size * num_images_per_prompt, self.tokenizer_max_length, self.transformer.config.joint_attention_dim, ), device=device, dtype=dtype, ) text_inputs = self.tokenizer_3( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] dtype = self.text_encoder_3.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds def _get_clip_prompt_embeds( self, prompt: Union[str, List[str]], num_images_per_prompt: int = 1, device: Optional[torch.device] = None, clip_skip: Optional[int] = None, clip_model_index: int = 0, ): device = device or self._execution_device clip_tokenizers = [self.tokenizer, self.tokenizer_2] clip_text_encoders = [self.text_encoder, 
self.text_encoder_2] tokenizer = clip_tokenizers[clip_model_index] text_encoder = clip_text_encoders[clip_model_index] prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = tokenizer( prompt, padding="max_length", max_length=self.tokenizer_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], prompt_3: Union[str, List[str]], device: Optional[torch.device] = None, num_images_per_prompt: int = 1, do_classifier_free_guidance: 
bool = True, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt_3: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, clip_skip: Optional[int] = None, max_sequence_length: int = 256, lora_scale: Optional[float] = None, ): r""" Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in all text-encoders prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is used in all text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 prompt_3 = prompt_3 or prompt prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=clip_skip, clip_model_index=0, ) prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( prompt=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=clip_skip, clip_model_index=1, ) clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) t5_prompt_embed = self._get_t5_prompt_embeds( prompt=prompt_3, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) clip_prompt_embeds = torch.nn.functional.pad( clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) ) prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt negative_prompt_3 = negative_prompt_3 or negative_prompt # normalize str to list 
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) negative_prompt_3 = ( batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 ) if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( negative_prompt, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=None, clip_model_index=0, ) negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( negative_prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, clip_skip=None, clip_model_index=1, ) negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) t5_negative_prompt_embed = self._get_t5_prompt_embeds( prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) negative_clip_prompt_embeds = torch.nn.functional.pad( negative_clip_prompt_embeds, (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), ) negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) negative_pooled_prompt_embeds = torch.cat( [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 ) if self.text_encoder is not None: if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: # 
Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.check_inputs def check_inputs( self, prompt, prompt_2, prompt_3, strength, negative_prompt=None, negative_prompt_2=None, negative_prompt_3=None, prompt_embeds=None, negative_prompt_embeds=None, pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt_3 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." 
) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) elif negative_prompt_2 is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) elif negative_prompt_3 is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." 
) if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(num_inference_steps * strength, num_inference_steps) t_start = int(max(num_inference_steps - init_timestep, 0)) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] if hasattr(self.scheduler, "set_begin_index"): self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start def prepare_latents( self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, image=None, timestep=None, is_strength_max=True, return_noise=False, return_image_latents=False, ): shape = ( batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if (image is None or timestep is None) and not is_strength_max: raise ValueError( "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." "However, either the image or the noise timestep has not been provided." ) if return_image_latents or (latents is None and not is_strength_max): image = image.to(device=device, dtype=dtype) if image.shape[1] == 16: image_latents = image else: image_latents = self._encode_vae_image(image=image, generator=generator) image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # if strength is 1. then initialise the latents to noise, else initial to image + noise latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise) else: noise = latents.to(device) latents = noise outputs = (latents,) if return_noise: outputs += (noise,) if return_image_latents: outputs += (image_latents,) return outputs def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: image_latents = retrieve_latents(self.vae.encode(image), generator=generator) image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor return image_latents def prepare_mask_latents( self, mask, masked_image, batch_size, num_images_per_prompt, height, width, dtype, device, generator, do_classifier_free_guidance, ): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision mask = torch.nn.functional.interpolate( mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) ) mask = 
mask.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt masked_image = masked_image.to(device=device, dtype=dtype) if masked_image.shape[1] == 16: masked_image_latents = masked_image else: masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: if not batch_size % mask.shape[0] == 0: raise ValueError( "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" " of masks that you pass is divisible by the total requested batch size." ) mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) if masked_image_latents.shape[0] < batch_size: if not batch_size % masked_image_latents.shape[0] == 0: raise ValueError( "The passed images and the required batch size don't match. Images are supposed to be duplicated" f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." " Make sure the number of images that you pass is divisible by the total requested batch size." 
) masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask masked_image_latents = ( torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents ) # aligning device to prevent device errors when concating it with the latent model input masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents @property def guidance_scale(self): return self._guidance_scale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @property def num_timesteps(self): return self._num_timesteps @property def interrupt(self): return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, prompt_3: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, mask_image: PipelineImageInput = None, masked_image_latents: PipelineImageInput = None, height: int = None, width: int = None, padding_mask_crop: Optional[int] = None, strength: float = 0.6, num_inference_steps: int = 50, timesteps: List[int] = None, guidance_scale: float = 7.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt_3: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = 
None, pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is will be used instead prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is will be used instead image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but if passing latents directly it is not encoded again. mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel (luminance) before use. 
If it's a numpy array or pytorch tensor, it should contain one color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`, or `(H, W)`. mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`): `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask latents tensor will ge generated by `mask_image`. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. padding_mask_crop (`int`, *optional*, defaults to `None`): The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large and contain information irrelevant for inpainting, such as background. strength (`float`, *optional*, defaults to 1.0): Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a starting point and more noise is added the higher the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising process runs for the full number of iterations specified in `num_inference_steps`. 
A value of 1 essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used instead negative_prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `negative_prompt` is used instead num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. 
latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of a plain tuple. callback_on_step_end (`Callable`, *optional*): A function that is called at the end of each denoising step during inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
`callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. Examples: Returns: [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs height = height or self.transformer.config.sample_size * self.vae_scale_factor width = width or self.transformer.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, prompt_2, prompt_3, strength, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, negative_prompt_3=negative_prompt_3, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._interrupt = False # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, prompt_3=prompt_3, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, negative_prompt_3=negative_prompt_3, do_classifier_free_guidance=self.do_classifier_free_guidance, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, device=device, clip_skip=self.clip_skip, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) # 3. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) # check that number of inference steps is not < 1 - as this doesn't make sense if num_inference_steps < 1: raise ValueError( f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise is_strength_max = strength == 1.0 # 4. 
Preprocess mask and image if padding_mask_crop is not None: crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) resize_mode = "fill" else: crops_coords = None resize_mode = "default" original_image = image init_image = self.image_processor.preprocess( image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode ) init_image = init_image.to(dtype=torch.float32) # 5. Prepare latent variables num_channels_latents = self.vae.config.latent_channels num_channels_transformer = self.transformer.config.in_channels return_image_latents = num_channels_transformer == 16 latents_outputs = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, image=init_image, timestep=latent_timestep, is_strength_max=is_strength_max, return_noise=True, return_image_latents=return_image_latents, ) if return_image_latents: latents, noise, image_latents = latents_outputs else: latents, noise = latents_outputs # 6. 
Prepare mask latent variables mask_condition = self.mask_processor.preprocess( mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords ) if masked_image_latents is None: masked_image = init_image * (mask_condition < 0.5) else: masked_image = masked_image_latents mask, masked_image_latents = self.prepare_mask_latents( mask_condition, masked_image, batch_size, num_images_per_prompt, height, width, prompt_embeds.dtype, device, generator, self.do_classifier_free_guidance, ) # match the inpainting pipeline and will be updated with input + mask inpainting model later if num_channels_transformer == 33: # default case for runwayml/stable-diffusion-inpainting num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] if ( num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels ): raise ValueError( f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `pipeline.transformer` or your `mask_image` or `image` input." ) elif num_channels_transformer != 16: raise ValueError( f"The transformer {self.transformer.__class__} should have 16 input channels or 33 input channels, not {self.transformer.config.in_channels}." ) # 7. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) if num_channels_transformer == 33: latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if num_channels_transformer == 16: init_latents_proper = image_latents if self.do_classifier_free_guidance: init_mask, _ = mask.chunk(2) else: init_mask = mask if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] init_latents_proper = self.scheduler.scale_noise( init_latents_proper, torch.tensor([noise_timestep]), noise ) latents = (1 - init_mask) * init_latents_proper + init_mask * latents if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) mask = callback_outputs.pop("mask", mask) masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: xm.mark_step() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 0 ] else: image = latents do_denormalize = [True] * image.shape[0] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if padding_mask_crop is not None: image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return StableDiffusion3PipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if 
not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from torch.nn import functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionAttendAndExcitePipeline >>> pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained( ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 ... ).to("cuda") >>> prompt = "a cat and a frog" >>> # use get_indices function to find out indices of the tokens you want to alter >>> pipe.get_indices(prompt) {0: '<|startoftext|>', 1: 'a', 2: 'cat', 3: 'and', 4: 'a', 5: 'frog', 6: '<|endoftext|>'} >>> token_indices = [2, 5] >>> seed = 6141 >>> generator = torch.Generator("cuda").manual_seed(seed) >>> images = pipe( ... prompt=prompt, ... token_indices=token_indices, ... guidance_scale=7.5, ... generator=generator, ... num_inference_steps=50, ... max_iter_to_alter=25, ... 
class AttentionStore:
    """Collects cross-attention maps produced during one UNet forward pass.

    Attention processors call the store once per attention layer. Maps whose
    number of spatial positions equals ``prod(attn_res)`` are kept per UNet
    location ("down", "mid", "up"). Once ``num_att_layers`` calls have been
    seen, the collected maps become the current ``attention_store`` and a new
    empty step store is started.
    """

    def __init__(self, attn_res):
        """
        Initialize an empty AttentionStore.

        :param attn_res: pair ``(height, width)`` of the attention resolution to record; only maps with
            ``height * width`` spatial positions are stored.
        """
        # Number of attention layers per forward pass; -1 until the owner sets it.
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.curr_step_index = 0
        self.attn_res = attn_res

    @staticmethod
    def get_empty_store():
        """Return a fresh mapping with one empty list per UNet location."""
        return {"down": [], "mid": [], "up": []}

    def reset(self):
        """Drop every recorded map and restart the layer counter."""
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def between_steps(self):
        """Publish the maps of the finished pass and start a new step store."""
        self.attention_store = self.step_store
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the maps recorded during the last completed pass."""
        return self.attention_store

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Record `attn` when it is a cross-attention map at the tracked resolution, then advance the counter."""
        wanted = self.cur_att_layer >= 0 and is_cross and attn.shape[1] == np.prod(self.attn_res)
        if wanted:
            self.step_store[place_in_unet].append(attn)

        # The counter advances for every layer, recorded or not.
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.between_steps()

    def aggregate_attention(self, from_where: List[str]) -> torch.Tensor:
        """Aggregates the attention across the different layers and heads at the specified resolution."""
        recorded = self.get_average_attention()
        rows, cols = self.attn_res[0], self.attn_res[1]
        stacked = torch.cat(
            [
                attn_map.reshape(-1, rows, cols, attn_map.shape[-1])
                for location in from_where
                for attn_map in recorded[location]
            ],
            dim=0,
        )
        # Average over the concatenated leading (layer/head) dimension.
        return stacked.sum(0) / stacked.shape[0]
= attn.to_q(hidden_states) is_cross = encoder_hidden_states is not None encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) # only need to store attention maps during the Attend and Excite process if attention_probs.requires_grad: self.attnstore(attention_probs, is_cross, self.place_in_unet) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) return hidden_states class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion and Attend-and-Excite. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and derive the image-processing helpers.

    Args:
        vae: Autoencoder registered as `self.vae`; its `config.block_out_channels` determines
            `self.vae_scale_factor`.
        text_encoder: Text encoder registered as `self.text_encoder`.
        tokenizer: Tokenizer registered as `self.tokenizer`.
        unet: Denoising model registered as `self.unet`.
        scheduler: Scheduler registered as `self.scheduler`.
        safety_checker: Optional safety checker; may be `None`, in which case a warning is logged when
            `requires_safety_checker` is `True`.
        feature_extractor: Feature extractor feeding the safety checker; required whenever
            `safety_checker` is not `None`.
        requires_safety_checker: Stored in the pipeline config; controls whether a missing safety checker
            triggers the warning.

    Raises:
        ValueError: If `safety_checker` is provided without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The `f` prefix is required: without it `{self.__class__}` appears verbatim in the message.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # One factor of two per VAE block transition.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) prompt_embeds_tuple = self.encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=lora_scale, **kwargs, ) # concatenate for backwards comp prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt def encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 
attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into # the tuple to access the hidden states from the desired layer. prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if self.text_encoder is not None: if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if 
self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
def check_inputs(
    self,
    prompt,
    indices,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the call arguments and raise on the first inconsistency found.

    Args:
        prompt: `str`, `list` or `None`; exactly one of `prompt` / `prompt_embeds` must be given.
        indices: Non-empty list of ints (single prompt) or non-empty list of non-empty lists of ints
            (one entry per prompt). Only the first element is type-checked.
        height: Image height; must be divisible by 8.
        width: Image width; must be divisible by 8.
        callback_steps: Positive integer.
        negative_prompt: Optional negative prompt; mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: Optional pre-computed embeddings; its first dimension is used as the batch size.
        negative_prompt_embeds: Optional pre-computed negative embeddings; must match the shape of
            `prompt_embeds` when both are given.

    Raises:
        ValueError: On invalid sizes, `callback_steps`, conflicting or missing prompt inputs, mismatched
            embedding shapes, or an `indices` batch size that differs from the prompt batch size.
        TypeError: If `indices` is not a non-empty list of ints or of non-empty lists of ints.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # Guard the `[0]` look-ups so that `[]` / `[[]]` hit the TypeError below instead of an IndexError.
    indices_is_nonempty_list = isinstance(indices, list) and len(indices) > 0
    indices_is_list_ints = indices_is_nonempty_list and isinstance(indices[0], int)
    indices_is_list_list_ints = (
        indices_is_nonempty_list
        and isinstance(indices[0], list)
        and len(indices[0]) > 0
        and isinstance(indices[0][0], int)
    )

    if not indices_is_list_ints and not indices_is_list_list_ints:
        raise TypeError("`indices` must be a list of ints or a list of a list of ints")

    indices_batch_size = 1 if indices_is_list_ints else len(indices)

    # The checks above guarantee `prompt` is a str/list, or else `prompt_embeds` is set.
    if isinstance(prompt, str):
        prompt_batch_size = 1
    elif isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    else:
        prompt_batch_size = prompt_embeds.shape[0]

    if indices_batch_size != prompt_batch_size:
        raise ValueError(
            f"indices batch size must be same as prompt batch size. indices batch size: {indices_batch_size}, prompt batch size: {prompt_batch_size}"
        )
@staticmethod
def _compute_max_attention_per_index(
    attention_maps: torch.Tensor,
    indices: List[int],
) -> List[torch.Tensor]:
    """Computes the maximum attention value for each of the tokens we wish to alter.

    Args:
        attention_maps: Attention tensor indexed as `[:, :, token]`; it is not modified.
        indices: Token positions, expressed relative to the un-sliced last axis of `attention_maps`.

    Returns:
        One scalar tensor per entry of `indices`: the maximum of the smoothed attention map for that token.
    """
    # Drop the first and last entries of the token axis. The scaling is done out of place: slicing returns a
    # view, so an in-place `*=` would silently overwrite the caller's `attention_maps`.
    attention_for_text = attention_maps[:, :, 1:-1] * 100
    attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1)

    # Shift indices since we removed the first token
    indices = [index - 1 for index in indices]

    # The smoothing module only holds a fixed kernel, so build it once instead of once per token.
    smoothing = GaussianSmoothing().to(attention_maps.device)

    # Extract the maximum values
    max_indices_list = []
    for i in indices:
        image = attention_for_text[:, :, i]
        # Reflect-pad by one pixel on each side so the 3x3 smoothing keeps the map size.
        padded = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect")
        image = smoothing(padded).squeeze(0).squeeze(0)
        max_indices_list.append(image.max())
    return max_indices_list
torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0] latents = latents - step_size * grad_cond return latents def _perform_iterative_refinement_step( self, latents: torch.Tensor, indices: List[int], loss: torch.Tensor, threshold: float, text_embeddings: torch.Tensor, step_size: float, t: int, max_refinement_steps: int = 20, ): """ Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent code according to our loss objective until the given threshold is reached for all tokens. """ iteration = 0 target_loss = max(0, 1.0 - threshold) while loss > target_loss: iteration += 1 latents = latents.clone().detach().requires_grad_(True) self.unet(latents, t, encoder_hidden_states=text_embeddings).sample self.unet.zero_grad() # Get max activation value for each subject token max_attention_per_index = self._aggregate_and_get_max_attention_per_token( indices=indices, ) loss = self._compute_loss(max_attention_per_index) if loss != 0: latents = self._update_latent(latents, loss, step_size) logger.info(f"\t Try {iteration}. loss: {loss}") if iteration >= max_refinement_steps: logger.info(f"\t Exceeded max number of iterations ({max_refinement_steps})! ") break # Run one more time but don't compute gradients and update the latents. 
def get_indices(self, prompt: str) -> Dict[int, str]:
    """Utility function to list the indices of the tokens you wish to alter.

    Args:
        prompt: The text to tokenize with `self.tokenizer`.

    Returns:
        A dict mapping each token position (0-based, in tokenizer output order) to the token string returned by
        `self.tokenizer.convert_ids_to_tokens`.
    """
    ids = self.tokenizer(prompt).input_ids
    tokens = self.tokenizer.convert_ids_to_tokens(ids)
    return dict(enumerate(tokens))
Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, max_iter_to_alter: int = 25, thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8}, scale_factor: int = 20, attn_res: Optional[Tuple[int]] = (16, 16), clip_skip: Optional[int] = None, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. token_indices (`List[int]`): The token indices to alter with attend-and-excite. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
max_iter_to_alter (`int`, *optional*, defaults to `25`): Number of denoising steps to apply attend-and-excite. The `max_iter_to_alter` denoising steps are when attend-and-excite is applied. For example, if `max_iter_to_alter` is `25` and there are a total of `30` denoising steps, the first `25` denoising steps applies attend-and-excite and the last `5` will not. thresholds (`dict`, *optional*, defaults to `{0: 0.05, 10: 0.5, 20: 0.8}`): Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in. scale_factor (`int`, *optional*, default to 20): Scale factor to control the step size of each attend-and-excite update. attn_res (`tuple`, *optional*, default computed from width and height): The 2D resolution of the semantic attention map. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, token_indices, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clip_skip=clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) if attn_res is None: attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32)) self.attention_store = AttentionStore(attn_res) original_attn_proc = self.unet.attn_processors self.register_attention_control() # default config for step size from original repo scale_range = np.linspace(1.0, 0.5, len(self.scheduler.timesteps)) step_size = scale_factor * np.sqrt(scale_range) text_embeddings = ( prompt_embeds[batch_size * num_images_per_prompt :] if do_classifier_free_guidance else prompt_embeds ) if isinstance(token_indices[0], int): token_indices = [token_indices] indices = [] for ind in token_indices: indices = indices + [ind] * num_images_per_prompt # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Attend and excite process with torch.enable_grad(): latents = latents.clone().detach().requires_grad_(True) updated_latents = [] for latent, index, text_embedding in zip(latents, indices, text_embeddings): # Forward pass of denoising with text conditioning latent = latent.unsqueeze(0) text_embedding = text_embedding.unsqueeze(0) self.unet( latent, t, encoder_hidden_states=text_embedding, cross_attention_kwargs=cross_attention_kwargs, ).sample self.unet.zero_grad() # Get max activation value for each subject token max_attention_per_index = self._aggregate_and_get_max_attention_per_token( indices=index, ) loss = self._compute_loss(max_attention_per_index=max_attention_per_index) # If this is an iterative refinement step, verify we have reached the desired threshold for all if i in thresholds.keys() and loss > 1.0 - thresholds[i]: loss, latent, max_attention_per_index = self._perform_iterative_refinement_step( latents=latent, indices=index, loss=loss, threshold=thresholds[i], 
text_embeddings=text_embedding, step_size=step_size[i], t=t, ) # Perform gradient update if i < max_iter_to_alter: if loss != 0: latent = self._update_latent( latents=latent, loss=loss, step_size=step_size[i], ) logger.info(f"Iteration {i} | Loss: {loss:0.4f}") updated_latents.append(latent) latents = torch.cat(updated_latents, dim=0) # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # 8. 
class GaussianSmoothing(torch.nn.Module):
    """
    Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed separately for each channel in the
    input using a depthwise convolution.

    Arguments:
        channels (int): Number of channels of the input tensors. Output will have this number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel.
        sigma (int, float, sequence): Standard deviation of the gaussian kernel.
        dim (int, optional): The number of dimensions of the data. Default value is 2 (spatial).

    Raises:
        RuntimeError: If `dim` is not 1, 2 or 3.
    """

    def __init__(
        self,
        channels: int = 1,
        kernel_size: int = 3,
        sigma: float = 0.5,
        dim: int = 2,
    ):
        super().__init__()

        # Validate `dim` up front so an unsupported value fails before any kernel is built.
        conv_by_dim = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}
        if dim not in conv_by_dim:
            raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim))

        if isinstance(kernel_size, int):
            kernel_size = [kernel_size] * dim
        # Broadcast any scalar sigma. Checking only for `float` let an int such as `sigma=1` fall through and
        # crash in the `zip` below.
        if isinstance(sigma, (int, float)):
            sigma = [sigma] * dim

        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [torch.arange(size, dtype=torch.float32) for size in kernel_size],
            indexing="ij",  # the historical default; passing it explicitly avoids the deprecation warning
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            # NOTE: the exponent divides by (2 * std) before squaring, i.e. exp(-(x - mean)^2 / (4 * std^2)),
            # which is wider than a textbook gaussian. Kept as is so existing outputs do not change.
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer("weight", kernel)
        self.groups = channels
        self.conv = conv_by_dim[dim]

    def forward(self, input):
        """
        Apply gaussian filter to input.

        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.

        Returns:
            filtered (torch.Tensor): Filtered output. No padding is applied, so each filtered axis shrinks by
            `kernel_size - 1`.
        """
        return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py ================================================ # Copyright 2024 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Usage example injected into the pipeline `__call__` docstring. The schedulers used in the snippet are imported
# explicitly so the example runs when copy-pasted.
EXAMPLE_DOC_STRING = """

        ```py
        >>> import PIL
        >>> import requests
        >>> import torch
        >>> from io import BytesIO

        >>> from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionDiffEditPipeline


        >>> def download_image(url):
        ...     response = requests.get(url)
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")


        >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"

        >>> init_image = download_image(img_url).resize((768, 768))

        >>> pipeline = StableDiffusionDiffEditPipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
        ... )

        >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
        >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
        >>> pipeline.enable_model_cpu_offload()

        >>> mask_prompt = "A bowl of fruits"
        >>> prompt = "A bowl of pears"

        >>> mask_image = pipeline.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt)
        >>> image_latents = pipeline.invert(image=init_image, prompt=mask_prompt).latents
        >>> image = pipeline(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0]
        ```
"""
def auto_corr_loss(hidden_states, generator=None):
    """Accumulate a multi-scale auto-correlation penalty over every (sample, channel) plane.

    For each plane, a random shift is drawn with `torch.randint` (optionally from `generator`), the plane is
    multiplied with a copy of itself rolled by that shift along dim 2 and then along dim 3, and the squared mean
    of each product is added to the total. The plane is then halved with 2x2 average pooling and the step is
    repeated; the last scale processed is the first one whose dim-2 size is at most 8.

    Args:
        hidden_states: 4-D tensor indexed as `[sample, channel, :, :]`.
        generator: Optional `torch.Generator` forwarded to `torch.randint`.

    Returns:
        The accumulated penalty: a scalar tensor, or the float `0.0` when there are no planes to visit.
    """
    num_samples, num_channels = hidden_states.shape[0], hidden_states.shape[1]
    total = 0.0
    for sample, channel in ((s, c) for s in range(num_samples) for c in range(num_channels)):
        plane = hidden_states[sample : sample + 1, channel : channel + 1, :, :]
        more_scales = True
        while more_scales:
            shift = torch.randint(plane.shape[2] // 2, (1,), generator=generator).item()
            # Same shift for both spatial axes, dim 2 first and then dim 3.
            for axis in (2, 3):
                shifted = torch.roll(plane, shifts=shift, dims=axis)
                total += (plane * shifted).mean() ** 2

            more_scales = plane.shape[2] > 8
            if more_scales:
                plane = torch.nn.functional.avg_pool2d(plane, kernel_size=2)
    return total
def preprocess_mask(mask, batch_size: int = 1):
    """Convert a mask into a binarized tensor with a batch axis and a single channel axis.

    Args:
        mask: A `torch.Tensor`, a `PIL.Image.Image`, a `np.ndarray`, or a list of any one of those. PIL images are
            converted to mode "L" and divided by 255. A tensor passed in is never modified.
        batch_size: Target batch size. A mask whose batch axis has size 1 is repeated `batch_size` times.

    Returns:
        A 4-D tensor with `shape[1] == 1` whose values are 0 where the input was `< 0.5` and 1 where it was
        `>= 0.5`.

    Raises:
        ValueError: If the mask batch size is neither 1 nor `batch_size` (checked only when `batch_size > 1`),
            if the mask has more than one channel, or if its values fall outside `[0, 1]`.
    """
    if not isinstance(mask, torch.Tensor):
        # preprocess mask
        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list):
            if isinstance(mask[0], PIL.Image.Image):
                mask = [np.array(m.convert("L")).astype(np.float32) / 255.0 for m in mask]
            # Not an `elif`: PIL input has just been turned into arrays and must continue here.
            if isinstance(mask[0], np.ndarray):
                mask = np.stack(mask, axis=0) if mask[0].ndim < 3 else np.concatenate(mask, axis=0)
                mask = torch.from_numpy(mask)
            elif isinstance(mask[0], torch.Tensor):
                mask = torch.stack(mask, dim=0) if mask[0].ndim < 3 else torch.cat(mask, dim=0)

    # Batch and add channel dim for single mask
    if mask.ndim == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)

    # Batch single mask or add channel dim
    if mask.ndim == 3:
        # Single batched mask, no channel dim or single mask not batched but channel dim
        if mask.shape[0] == 1:
            mask = mask.unsqueeze(0)

        # Batched masks no channel dim
        else:
            mask = mask.unsqueeze(1)

    # Check mask shape
    if batch_size > 1:
        if mask.shape[0] == 1:
            mask = torch.cat([mask] * batch_size)
        elif mask.shape[0] > 1 and mask.shape[0] != batch_size:
            raise ValueError(
                f"`mask_image` with batch size {mask.shape[0]} cannot be broadcasted to batch size {batch_size} "
                f"inferred by prompt inputs"
            )

    if mask.shape[1] != 1:
        raise ValueError(f"`mask_image` must have 1 channel, but has {mask.shape[1]} channels")

    # Check mask is in [0, 1]
    if mask.min() < 0 or mask.max() > 1:
        raise ValueError("`mask_image` should be in [0, 1] range")

    # Binarize mask on a private copy: `unsqueeze` returns views, so writing in place here would overwrite the
    # tensor the caller passed in.
    mask = mask.clone()
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1

    return mask
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    inverse_scheduler: DDIMInverseScheduler,
    requires_safety_checker: bool = True,
):
    """Validate the components, patch outdated configs and register all modules.

    Outdated scheduler / UNet configuration values (`steps_offset != 1`,
    `skip_prk_steps is False`, `sample_size < 64` on pre-0.9.0 UNets) are reported
    through `deprecate(...)` and then overwritten on the component's internal config.

    Args:
        vae: Autoencoder used to move between image space and latent space.
        text_encoder: Text encoder that produces the prompt embeddings.
        tokenizer: Tokenizer paired with `text_encoder`.
        unet: Denoising UNet.
        scheduler: Scheduler used together with `unet` for denoising.
        safety_checker: Optional safety checker; may be `None`.
        feature_extractor: Feature extractor feeding `safety_checker`.
        inverse_scheduler: Scheduler used for latent inversion.
        requires_safety_checker: Stored in the pipeline config; when `True` and
            `safety_checker` is `None`, a warning is logged.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    # Scheduler configs that predate `steps_offset=1` are patched in place.
    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    # Same treatment for schedulers that explicitly disabled `skip_prk_steps`.
    if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration"
            " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
            " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
            " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
            " Hub, it would be very nice if you could open a Pull request for the"
            " `scheduler/scheduler_config.json` file"
        )
        deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["skip_prk_steps"] = True
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message must be an f-string, otherwise `{self.__class__}` is printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # UNets saved before 0.9.0 with `sample_size < 64` get their sample size bumped to 64.
    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        inverse_scheduler=inverse_scheduler,
    )
    # Spatial down-scaling factor between image space and latent space.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Turn text prompts into text-encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt(s) to encode. Not needed when `prompt_embeds` is given.
        device (`torch.device`):
            Device the embeddings are moved to.
        num_images_per_prompt (`int`):
            How many times each prompt embedding is repeated.
        do_classifier_free_guidance (`bool`):
            Whether negative (unconditional) embeddings must be produced as well.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt(s) to steer away from. Only used when guidance is enabled and
            `negative_prompt_embeds` is not given; defaults to empty strings.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed prompt embeddings; skips tokenization/encoding of `prompt`.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed negative embeddings; skips encoding of `negative_prompt`.
        lora_scale (`float`, *optional*):
            Scale applied to the text encoder's LoRA layers while encoding.
        clip_skip (`int`, *optional*):
            Number of final CLIP layers to skip; `1` uses the pre-final layer output.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`, in that order.
    """
    lora_aware = isinstance(self, StableDiffusionLoraLoaderMixin)

    if lora_aware and lora_scale is not None:
        # Kept on the pipeline so the patched text-encoder LoRA code can read it.
        self._lora_scale = lora_scale

        if USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder, lora_scale)
        else:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)

    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # Textual inversion may expand a single token into several.
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        tokenized = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokenized.input_ids
        full_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        was_truncated = full_ids.shape[-1] >= input_ids.shape[-1] and not torch.equal(input_ids, full_ids)
        if was_truncated:
            dropped_text = self.tokenizer.batch_decode(full_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {dropped_text}"
            )

        encoder_cfg = self.text_encoder.config
        if hasattr(encoder_cfg, "use_attention_mask") and encoder_cfg.use_attention_mask:
            attention_mask = tokenized.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is not None:
            encoder_out = self.text_encoder(
                input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Last output entry is the tuple of per-layer hidden states; pick the
            # requested layer counting from the end.
            skipped_layer_states = encoder_out[-1][-(clip_skip + 1)]
            # The regular last hidden state goes through the final LayerNorm, so an
            # intermediate layer has to be normalized the same way.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(skipped_layer_states)
        else:
            prompt_embeds = self.text_encoder(input_ids.to(device), attention_mask=attention_mask)[0]

    if self.text_encoder is not None:
        target_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        target_dtype = self.unet.dtype
    else:
        target_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=target_dtype, device=device)

    # Repeat every embedding `num_images_per_prompt` times (repeat + view is mps friendly).
    n_prompts, n_tokens, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(n_prompts * num_images_per_prompt, n_tokens, -1)

    if do_classifier_free_guidance and negative_prompt_embeds is None:
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the negative prompt to the same token length as the positive embeddings.
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=prompt_embeds.shape[1],
            truncation=True,
            return_tensors="pt",
        )

        encoder_cfg = self.text_encoder.config
        if hasattr(encoder_cfg, "use_attention_mask") and encoder_cfg.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )[0]

    if do_classifier_free_guidance:
        neg_tokens = negative_prompt_embeds.shape[1]
        negative_prompt_embeds = negative_prompt_embeds.to(dtype=target_dtype, device=device)
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, neg_tokens, -1)

    if self.text_encoder is not None and lora_aware and USE_PEFT_BACKEND:
        # Undo the LoRA scaling applied at the top of this method.
        unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Scheduler `step` signatures differ, so `eta` and `generator` are only forwarded
    when the scheduler's `step` declares a parameter with that name. `eta` is the η
    of the DDIM paper (https://arxiv.org/abs/2010.02502) and should lie in [0, 1].

    Returns:
        `dict`: Subset of `{"eta": eta, "generator": generator}` accepted by `step`.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_parameters}
def check_inputs(
    self,
    prompt,
    strength,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the prompt-related call arguments.

    Args:
        prompt: `str`, `list` or `None`; mutually exclusive with `prompt_embeds`.
        strength: Number in the closed interval [0, 1].
        callback_steps: Positive `int`.
        negative_prompt: Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: Pre-computed embeddings; exactly one of `prompt` /
            `prompt_embeds` must be given.
        negative_prompt_embeds: Pre-computed negative embeddings; must have the same
            `.shape` as `prompt_embeds` when both are passed.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    if strength is None or strength < 0 or strength > 1:
        raise ValueError(
            f"The value of `strength` should be in [0.0, 1.0] but is {strength} of type {type(strength)}."
        )

    # `None` is not an `int`, so the isinstance test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
def get_inverse_timesteps(self, num_inference_steps, strength, device):
    """Select the part of the inverse schedule that `strength` asks for.

    Args:
        num_inference_steps: Total number of steps in the schedule.
        strength: Fraction of the schedule to run.
        device: Unused; kept for signature compatibility.

    Returns:
        `tuple`: `(timesteps, number_of_steps_to_run)`, where `timesteps` is taken
        from the front of `self.inverse_scheduler.timesteps`.
    """
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    steps_to_skip = max(num_inference_steps - steps_to_run, 0)

    if steps_to_skip == 0:
        # `[:-0]` would be an empty slice, so the full schedule is returned directly.
        return self.inverse_scheduler.timesteps, num_inference_steps

    kept_timesteps = self.inverse_scheduler.timesteps[:-steps_to_skip]
    return kept_timesteps, num_inference_steps - steps_to_skip
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) image = image.to(device=device, dtype=dtype) if image.shape[1] == 4: latents = image else: if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if isinstance(generator, list): latents = [ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] latents = torch.cat(latents, dim=0) else: latents = self.vae.encode(image).latent_dist.sample(generator) latents = self.vae.config.scaling_factor * latents if batch_size != latents.shape[0]: if batch_size % latents.shape[0] == 0: # expand image_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def generate_mask(
    self,
    image: Union[torch.Tensor, PIL.Image.Image] = None,
    target_prompt: Optional[Union[str, List[str]]] = None,
    target_negative_prompt: Optional[Union[str, List[str]]] = None,
    target_prompt_embeds: Optional[torch.Tensor] = None,
    target_negative_prompt_embeds: Optional[torch.Tensor] = None,
    source_prompt: Optional[Union[str, List[str]]] = None,
    source_negative_prompt: Optional[Union[str, List[str]]] = None,
    source_prompt_embeds: Optional[torch.Tensor] = None,
    source_negative_prompt_embeds: Optional[torch.Tensor] = None,
    num_maps_per_mask: Optional[int] = 10,
    mask_encode_strength: Optional[float] = 0.5,
    mask_thresholding_ratio: Optional[float] = 3.0,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    output_type: Optional[str] = "np",
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Generate a latent mask given a mask prompt, a target prompt, and an image.

    Args:
        image (`PIL.Image.Image`):
            `Image` or tensor representing an image batch to be used for computing the mask.
        target_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide semantic mask generation. If not defined, you need to pass
            `target_prompt_embeds`.
        target_negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `target_negative_prompt_embeds` instead. Ignored when not using guidance
            (`guidance_scale <= 1`).
        target_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `target_prompt` input argument.
        target_negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting).
            If not provided, they are generated from the `target_negative_prompt` input argument.
        source_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide semantic mask generation using DiffEdit. If not defined, you need to
            pass `source_prompt_embeds` instead.
        source_negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide semantic mask generation away from using DiffEdit. If not defined,
            you need to pass `source_negative_prompt_embeds` instead.
        source_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings to guide the semantic mask generation. Can be used to easily tweak
            text inputs (prompt weighting). If not provided, text embeddings are generated from the
            `source_prompt` input argument.
        source_negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings to negatively guide the semantic mask generation. Can be used to
            easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from
            the `source_negative_prompt` input argument.
        num_maps_per_mask (`int`, *optional*, defaults to 10):
            The number of noise maps sampled to generate the semantic mask using DiffEdit.
        mask_encode_strength (`float`, *optional*, defaults to 0.5):
            The strength of the noise maps sampled to generate the semantic mask using DiffEdit. Must be
            between 0 and 1.
        mask_thresholding_ratio (`float`, *optional*, defaults to 3.0):
            The maximum multiple of the mean absolute difference used to clamp the semantic guidance map before
            mask binarization.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when
            `guidance_scale > 1`.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated mask. Choose between `"pil"` (`PIL.Image`) and `"np"`
            (`np.array`).
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the
            [`~models.attention_processor.AttnProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its `"scale"` entry, when present, is also used as the LoRA scale of the text encoder.

    Examples:

    Returns:
        `List[PIL.Image.Image]` or `np.array`:
            When returning a `List[PIL.Image.Image]`, the list consists of a batch of single-channel binary
            images with dimensions `(height // self.vae_scale_factor, width // self.vae_scale_factor)`. If
            it's `np.array`, the shape is `(batch_size, height // self.vae_scale_factor, width //
            self.vae_scale_factor)`.

    Raises:
        ValueError: If the prompt arguments are inconsistent, if `num_maps_per_mask` is not a positive
            integer, if `mask_thresholding_ratio` is not positive, or if `image` is `None`.
    """

    # 1. Check inputs (Provide dummy argument for callback_steps)
    self.check_inputs(
        target_prompt,
        mask_encode_strength,
        1,
        target_negative_prompt,
        target_prompt_embeds,
        target_negative_prompt_embeds,
    )

    self.check_source_inputs(
        source_prompt,
        source_negative_prompt,
        source_prompt_embeds,
        source_negative_prompt_embeds,
    )

    if not isinstance(num_maps_per_mask, int) or num_maps_per_mask <= 0:
        raise ValueError(
            f"`num_maps_per_mask` has to be a positive integer but is {num_maps_per_mask} of type"
            f" {type(num_maps_per_mask)}."
        )

    if mask_thresholding_ratio is None or mask_thresholding_ratio <= 0:
        raise ValueError(
            f"`mask_thresholding_ratio` has to be positive but is {mask_thresholding_ratio} of type"
            f" {type(mask_thresholding_ratio)}."
        )

    # Same guard as `invert`: fail early with a clear message instead of inside preprocessing.
    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    # 2. Define call parameters
    if target_prompt is not None and isinstance(target_prompt, str):
        batch_size = 1
    elif target_prompt is not None and isinstance(target_prompt, list):
        batch_size = len(target_prompt)
    else:
        batch_size = target_prompt_embeds.shape[0]
    if cross_attention_kwargs is None:
        cross_attention_kwargs = {}

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompts
    # The LoRA scale used to be computed and then thrown away; forward it to the text encoder.
    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None)

    # `encode_prompt` returns `(prompt_embeds, negative_prompt_embeds)` -- positive first.
    target_prompt_embeds, target_negative_prompt_embeds = self.encode_prompt(
        target_prompt,
        device,
        num_maps_per_mask,
        do_classifier_free_guidance,
        target_negative_prompt,
        prompt_embeds=target_prompt_embeds,
        negative_prompt_embeds=target_negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if do_classifier_free_guidance:
        target_prompt_embeds = torch.cat([target_negative_prompt_embeds, target_prompt_embeds])

    source_prompt_embeds, source_negative_prompt_embeds = self.encode_prompt(
        source_prompt,
        device,
        num_maps_per_mask,
        do_classifier_free_guidance,
        source_negative_prompt,
        prompt_embeds=source_prompt_embeds,
        negative_prompt_embeds=source_negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )
    if do_classifier_free_guidance:
        source_prompt_embeds = torch.cat([source_negative_prompt_embeds, source_prompt_embeds])

    # 4. Preprocess image (one copy of every image per noise map)
    image = self.image_processor.preprocess(image).repeat_interleave(num_maps_per_mask, dim=0)

    # 5. Set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, _ = self.get_timesteps(num_inference_steps, mask_encode_strength, device)
    encode_timestep = timesteps[0]

    # 6. Prepare image latents and add noise with specified strength
    image_latents = self.prepare_image_latents(
        image, batch_size * num_maps_per_mask, self.vae.dtype, device, generator
    )
    noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=self.vae.dtype)
    image_latents = self.scheduler.add_noise(image_latents, noise, encode_timestep)

    # One copy of the latents per embedding group: (uncond, cond) x (source, target) with
    # guidance, (source, target) without.
    latent_model_input = torch.cat([image_latents] * (4 if do_classifier_free_guidance else 2))
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, encode_timestep)

    # 7. Predict the noise residual
    prompt_embeds = torch.cat([source_prompt_embeds, target_prompt_embeds])
    noise_pred = self.unet(
        latent_model_input,
        encode_timestep,
        encoder_hidden_states=prompt_embeds,
        cross_attention_kwargs=cross_attention_kwargs,
    ).sample

    if do_classifier_free_guidance:
        # Chunk order mirrors the concatenation order above:
        # [source negative, source positive, target negative, target positive].
        noise_pred_neg_src, noise_pred_source, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4)
        noise_pred_source = noise_pred_neg_src + guidance_scale * (noise_pred_source - noise_pred_neg_src)
        noise_pred_target = noise_pred_uncond + guidance_scale * (noise_pred_target - noise_pred_uncond)
    else:
        noise_pred_source, noise_pred_target = noise_pred.chunk(2)

    # 8. Compute the mask from the absolute difference of predicted noise residuals
    # TODO: Consider smoothing mask guidance map
    mask_guidance_map = (
        torch.abs(noise_pred_target - noise_pred_source)
        .reshape(batch_size, num_maps_per_mask, *noise_pred_target.shape[-3:])
        .mean([1, 2])
    )
    clamp_magnitude = mask_guidance_map.mean() * mask_thresholding_ratio
    semantic_mask_image = mask_guidance_map.clamp(0, clamp_magnitude) / clamp_magnitude
    semantic_mask_image = torch.where(semantic_mask_image <= 0.5, 0, 1)
    mask_image = semantic_mask_image.cpu().numpy()

    # 9. Convert to Numpy array or PIL.
    if output_type == "pil":
        mask_image = self.image_processor.numpy_to_pil(mask_image)

    # Offload all models
    self.maybe_free_model_hooks()

    return mask_image
guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). generator (`torch.Generator`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. decode_latents (`bool`, *optional*, defaults to `False`): Whether or not to decode the inverted latents into a generated image. Setting this argument to `True` decodes all inverted latents for each timestep into a list of generated images. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.DiffEditInversionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. 
callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`~models.attention_processor.AttnProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). lambda_auto_corr (`float`, *optional*, defaults to 20.0): Lambda parameter to control auto correction. lambda_kl (`float`, *optional*, defaults to 20.0): Lambda parameter to control Kullback-Leibler divergence output. num_reg_steps (`int`, *optional*, defaults to 0): Number of regularization loss steps. num_auto_corr_rolls (`int`, *optional*, defaults to 5): Number of auto correction roll steps. Examples: Returns: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is the inverted latents tensors ordered by increasing noise, and the second is the corresponding decoded images if `decode_latents` is `True`, otherwise `None`. """ # 1. Check inputs self.check_inputs( prompt, inpaint_strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) if image is None: raise ValueError("`image` input cannot be undefined.") # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if cross_attention_kwargs is None: cross_attention_kwargs = {} device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Preprocess image image = self.image_processor.preprocess(image) # 4. Prepare latent variables num_images_per_prompt = 1 latents = self.prepare_image_latents( image, batch_size * num_images_per_prompt, self.vae.dtype, device, generator ) # 5. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 6. Prepare timesteps self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_inverse_timesteps(num_inference_steps, inpaint_strength, device) # 7. Noising loop where we obtain the intermediate noised latent image for each timestep. 
num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order inverted_latents = [] with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # regularization of the noise prediction (not in original code or paper but borrowed from Pix2PixZero) if num_reg_steps > 0: with torch.enable_grad(): for _ in range(num_reg_steps): if lambda_auto_corr > 0: for _ in range(num_auto_corr_rolls): var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) # Derive epsilon from model output before regularizing to IID standard normal var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) l_ac = auto_corr_loss(var_epsilon, generator=generator) l_ac.backward() grad = var.grad.detach() / num_auto_corr_rolls noise_pred = noise_pred - lambda_auto_corr * grad if lambda_kl > 0: var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) # Derive epsilon from model output before regularizing to IID standard normal var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) l_kld = kl_divergence(var_epsilon) l_kld.backward() grad = var.grad.detach() noise_pred = noise_pred - lambda_kl * grad noise_pred = noise_pred.detach() # compute the previous noisy sample x_t -> x_t-1 latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample 
inverted_latents.append(latents.detach().clone()) # call the callback, if provided if i == len(timesteps) - 1 or ( (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) assert len(inverted_latents) == len(timesteps) latents = torch.stack(list(reversed(inverted_latents)), 1) # 8. Post-processing image = None if decode_latents: image = self.decode_latents(latents.flatten(0, 1)) # 9. Convert to PIL. if decode_latents and output_type == "pil": image = self.image_processor.numpy_to_pil(image) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (latents, image) return DiffEditInversionPipelineOutput(latents=latents, images=image) @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Optional[Union[str, List[str]]] = None, mask_image: Union[torch.Tensor, PIL.Image.Image] = None, image_latents: Union[torch.Tensor, PIL.Image.Image] = None, inpaint_strength: Optional[float] = 0.8, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: int = None, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 
mask_image (`PIL.Image.Image`): `Image` or tensor representing an image batch to mask the generated image. White pixels in the mask are repainted, while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, 1, H, W)`. image_latents (`PIL.Image.Image` or `torch.Tensor`): Partially noised image latents from the inversion process to be used as inputs for image generation. inpaint_strength (`float`, *optional*, defaults to 0.8): Indicates extent to inpaint the masked area. Must be between 0 and 1. When `inpaint_strength` is 1, the denoising process is run on the masked area for the full number of iterations specified in `num_inference_steps`. `image_latents` is used as a reference for the masked area, and adding more noise to a region increases `inpaint_strength`. If `inpaint_strength` is 0, no inpainting occurs. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. 
Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ # 1. Check inputs self.check_inputs( prompt, inpaint_strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) if mask_image is None: raise ValueError( "`mask_image` input cannot be undefined. Use `generate_mask()` to compute `mask_image` from text prompts." ) if image_latents is None: raise ValueError( "`image_latents` input cannot be undefined. Use `invert()` to compute `image_latents` from input images." ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if cross_attention_kwargs is None: cross_attention_kwargs = {} device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. 
Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Preprocess mask mask_image = preprocess_mask(mask_image, batch_size) latent_height, latent_width = mask_image.shape[-2:] mask_image = torch.cat([mask_image] * num_images_per_prompt) mask_image = mask_image.to(device=device, dtype=prompt_embeds.dtype) # 5. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device) # 6. 
Preprocess image latents if isinstance(image_latents, list) and any(isinstance(l, torch.Tensor) and l.ndim == 5 for l in image_latents): image_latents = torch.cat(image_latents).detach() elif isinstance(image_latents, torch.Tensor) and image_latents.ndim == 5: image_latents = image_latents.detach() else: image_latents = self.image_processor.preprocess(image_latents).detach() latent_shape = (self.vae.config.latent_channels, latent_height, latent_width) if image_latents.shape[-3:] != latent_shape: raise ValueError( f"Each latent image in `image_latents` must have shape {latent_shape}, " f"but has shape {image_latents.shape[-3:]}" ) if image_latents.ndim == 4: image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape) if image_latents.shape[:2] != (batch_size, len(timesteps)): raise ValueError( f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)}" f" timesteps, but has batch size {image_latents.shape[0]} with latent images from" f" {image_latents.shape[1]} timesteps." ) image_latents = image_latents.transpose(0, 1).repeat_interleave(num_images_per_prompt, dim=1) image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. 
Denoising loop latents = image_latents[0].clone() num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # mask with inverted latents from appropriate timestep - use original image latent for last step latents = latents * mask_image + image_latents[i] * (1 - mask_image) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image, 
has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: src/diffusers/pipelines/stable_diffusion_gligen/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"] _import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline from .pipeline_stable_diffusion_gligen_text_image import StableDiffusionGLIGENTextImagePipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py ================================================ # Copyright 2024 The GLIGEN Authors and HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import PIL.Image import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention import GatedSelfAttentionDense from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionGLIGENPipeline >>> from diffusers.utils import load_image >>> # Insert objects described by text at the region defined by bounding boxes >>> pipe = StableDiffusionGLIGENPipeline.from_pretrained( ... "masterful/gligen-1-4-inpainting-text-box", variant="fp16", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> input_image = load_image( ... 
"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/livingroom_modern.png" ... ) >>> prompt = "a birthday cake" >>> boxes = [[0.2676, 0.6088, 0.4773, 0.7183]] >>> phrases = ["a birthday cake"] >>> images = pipe( ... prompt=prompt, ... gligen_phrases=phrases, ... gligen_inpaint_image=input_image, ... gligen_boxes=boxes, ... gligen_scheduled_sampling_beta=1, ... output_type="pil", ... num_inference_steps=50, ... ).images >>> images[0].save("./gligen-1-4-inpainting-text-box.jpg") >>> # Generate an image described by the prompt and >>> # insert objects described by text at the region defined by bounding boxes >>> pipe = StableDiffusionGLIGENPipeline.from_pretrained( ... "masterful/gligen-1-4-generation-text-box", variant="fp16", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> prompt = "a waterfall and a modern high speed train running through the tunnel in a beautiful forest with fall foliage" >>> boxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]] >>> phrases = ["a waterfall", "a modern high speed train running through the tunnel"] >>> images = pipe( ... prompt=prompt, ... gligen_phrases=phrases, ... gligen_boxes=boxes, ... gligen_scheduled_sampling_beta=1, ... output_type="pil", ... num_inference_steps=50, ... ).images >>> images[0].save("./gligen-1-4-generation-text-box.jpg") ``` """ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin): r""" Pipeline for text-to-image generation using Stable Diffusion with Grounded-Language-to-Image Generation (GLIGEN). This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.). Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and set up the image processor.

    A warning is logged when `safety_checker` is `None` while `requires_safety_checker` is `True`.
    `requires_safety_checker` is stored in the pipeline config.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message must be an f-string, otherwise `{self.__class__}` is printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # Scale factor derived from the number of VAE blocks: 2 ** (num_blocks - 1).
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
# NOTE: adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
# (restyled, behaviour unchanged; no longer a verbatim copy).
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt (and, for classifier free guidance, the negative prompt) into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device the embeddings are moved to
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt; embeddings are repeated this many times
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead; when both are missing, empty strings are used. Only used when
            `do_classifier_free_guidance` is `True`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they will be generated from
            `negative_prompt`.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`. The second element is returned as passed in when
        `do_classifier_free_guidance` is `False`.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are of different types.
        ValueError: If a list `negative_prompt` does not match the batch size.
    """

    def _attention_mask_for(tokenized):
        # Only pass an attention mask when the text encoder config asks for one.
        if getattr(self.text_encoder.config, "use_attention_mask", False):
            return tokenized.attention_mask.to(device)
        return None

    lora_capable = isinstance(self, StableDiffusionLoraLoaderMixin)

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and lora_capable:
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if USE_PEFT_BACKEND:
            scale_lora_layers(self.text_encoder, lora_scale)
        else:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)

    # Batch size comes from the prompt when given, otherwise from the pre-computed embeddings.
    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        tokenizer_max = self.tokenizer.model_max_length
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer_max,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Warn about whatever part of the prompt did not fit into the tokenizer's window.
        was_truncated = untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        )
        if was_truncated:
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        attention_mask = _attention_mask_for(text_inputs)

        if clip_skip is not None:
            encoder_output = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            skipped_layer_states = encoder_output[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(skipped_layer_states)
        else:
            encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = encoder_output[0]

    # Target dtype: prefer the text encoder's, then the UNet's, else keep the embeddings' own dtype.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        else:
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            if isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the negative prompt to the same sequence length as the positive embeddings.
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=prompt_embeds.shape[1],
            truncation=True,
            return_tensors="pt",
        )

        uncond_output = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=_attention_mask_for(uncond_input),
        )
        negative_prompt_embeds = uncond_output[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        neg_seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, neg_seq_len, -1)

    if self.text_encoder is not None and lora_capable and USE_PEFT_BACKEND:
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker over `image`.

    Returns:
        `tuple`: `(image, has_nsfw_concept)`. When `self.safety_checker` is `None` the image is returned
        untouched and `has_nsfw_concept` is `None`; otherwise both values come from the safety checker.
    """
    if self.safety_checker is None:
        has_nsfw_concept = None
    else:
        # The feature extractor is fed PIL images, so convert from tensor or numpy first.
        if torch.is_tensor(image):
            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
        else:
            feature_extractor_input = self.image_processor.numpy_to_pil(image)
        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
    return image, has_nsfw_concept


def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Args:
        generator: Random generator(s); forwarded only if `step` has a `generator` parameter.
        eta: DDIM η value; forwarded only if `step` has an `eta` parameter.

    Returns:
        `dict`: Contains `"eta"` and/or `"generator"`, depending on the scheduler's `step` signature.
    """
    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]

    # Inspect the signature once; membership tests work directly on the parameter mapping.
    step_parameters = inspect.signature(self.scheduler.step).parameters

    extra_step_kwargs = {}
    if "eta" in step_parameters:
        extra_step_kwargs["eta"] = eta

    # check if the scheduler accepts generator
    if "generator" in step_parameters:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs


def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    gligen_phrases,
    gligen_boxes,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the pipeline call arguments before any model is run.

    Raises:
        ValueError: If `height`/`width` are not divisible by 8, `callback_steps` is not a positive integer,
            `prompt`/`prompt_embeds` (or their negative counterparts) are inconsistently supplied, the
            embedding shapes differ, `gligen_phrases` or `gligen_boxes` is missing, or their lengths differ.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None` is not an `int`, so this single test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # Both grounding inputs are required; without this guard a missing one
    # surfaces as an opaque `TypeError` from `len(None)`.
    if gligen_phrases is None or gligen_boxes is None:
        raise ValueError(
            "`gligen_phrases` and `gligen_boxes` must both be provided, but got: `gligen_phrases`"
            f" {gligen_phrases} and `gligen_boxes` {gligen_boxes}."
        )

    if len(gligen_phrases) != len(gligen_boxes):
        raise ValueError(
            "length of `gligen_phrases` and `gligen_boxes` has to be same, but"
            f" got: `gligen_phrases` {len(gligen_phrases)} != `gligen_boxes` {len(gligen_boxes)}"
        )


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """Create (or move) the initial latents and scale them by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    shape = (
        batch_size,
        num_channels_latents,
        int(height) // self.vae_scale_factor,
        int(width) // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents
def enable_fuser(self, enabled=True):
    """Set the `enabled` flag on every `GatedSelfAttentionDense` module of `self.unet`."""
    for module in self.unet.modules():
        # Exact type match (not `isinstance`), as in the original implementation.
        if type(module) is GatedSelfAttentionDense:
            module.enabled = enabled


def draw_inpaint_mask_from_boxes(self, boxes, size):
    """Build an inpainting mask that is 0 inside the given boxes and 1 elsewhere.

    Args:
        boxes: Iterable of `[xmin, ymin, xmax, ymax]` boxes with coordinates relative to the mask size.
        size: `(height, width)` of the mask to create.

    Returns:
        `torch.Tensor`: Mask of shape `(size[0], size[1])`.
    """
    height, width = size[0], size[1]
    inpaint_mask = torch.ones(height, width)
    for box in boxes:
        # The mask is indexed [row, column] == [y, x]: x coordinates scale with
        # the width (size[1]) and y coordinates with the height (size[0]).
        x0, x1 = box[0] * width, box[2] * width
        y0, y1 = box[1] * height, box[3] * height
        inpaint_mask[int(y0) : int(y1), int(x0) : int(x1)] = 0
    return inpaint_mask


def crop(self, im, new_width, new_height):
    """Return the centered `new_width` x `new_height` crop of `im` via `im.crop`."""
    width, height = im.size
    left = (width - new_width) / 2
    top = (height - new_height) / 2
    right = (width + new_width) / 2
    bottom = (height + new_height) / 2
    return im.crop((left, top, right, bottom))


def target_size_center_crop(self, im, new_hw):
    """Center-crop `im` to a square if needed, then resize it to `new_hw` x `new_hw`."""
    width, height = im.size
    if width != height:
        short_side = min(height, width)
        im = self.crop(im, short_side, short_side)
    return im.resize((new_hw, new_hw), PIL.Image.LANCZOS)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    gligen_scheduled_sampling_beta: float = 0.3,
    gligen_phrases: List[str] = None,
    gligen_boxes: List[List[float]] = None,
    gligen_inpaint_image: Optional[PIL.Image.Image] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        gligen_phrases (`List[str]`):
            The phrases to guide what to include in each of the regions defined by the corresponding
            `gligen_boxes`. There should only be one phrase per bounding box.
        gligen_boxes (`List[List[float]]`):
            The bounding boxes that identify rectangular regions of the image that are going to be filled with the
            content described by the corresponding `gligen_phrases`. Each rectangular box is defined as a
            `List[float]` of 4 elements `[xmin, ymin, xmax, ymax]` where each value is between [0,1].
        gligen_inpaint_image (`PIL.Image.Image`, *optional*):
            The input image, if provided, is inpainted with objects described by the `gligen_boxes` and
            `gligen_phrases`. Otherwise, it is treated as a generation task on a blank input image.
        gligen_scheduled_sampling_beta (`float`, defaults to 0.3):
            Scheduled Sampling factor from [GLIGEN: Open-Set Grounded Text-to-Image
            Generation](https://arxiv.org/pdf/2301.07093.pdf). Scheduled Sampling factor is only varied for
            scheduled sampling during inference for improved quality and controllability.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that calls every `callback_steps` steps during inference. The function is called with the
            following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. If not specified, the callback is called at
            every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            The dictionary is copied before the `"gligen"` entry is added, so the caller's object is not
            modified.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        gligen_phrases,
        gligen_boxes,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clip_skip=clip_skip,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5.1 Prepare GLIGEN variables
    max_objs = 30
    if len(gligen_boxes) > max_objs:
        warnings.warn(
            f"More than {max_objs} objects found. Only first {max_objs} objects will be processed.",
            FutureWarning,
        )
        gligen_phrases = gligen_phrases[:max_objs]
        gligen_boxes = gligen_boxes[:max_objs]
    # prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
    # Get tokens for phrases from pre-trained CLIPTokenizer
    tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(device)
    # For the token, we use the same pre-trained text encoder
    # to obtain its text feature
    _text_embeddings = self.text_encoder(**tokenizer_inputs).pooler_output
    n_objs = len(gligen_boxes)
    # For each entity, described in phrases, is denoted with a bounding box,
    # we represent the location information as (xmin,ymin,xmax,ymax)
    boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype)
    boxes[:n_objs] = torch.tensor(gligen_boxes)
    text_embeddings = torch.zeros(
        max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype
    )
    text_embeddings[:n_objs] = _text_embeddings
    # Generate a mask for each object that is entity described by phrases
    masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype)
    masks[:n_objs] = 1

    repeat_batch = batch_size * num_images_per_prompt
    boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
    text_embeddings = text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
    masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()
    if do_classifier_free_guidance:
        repeat_batch = repeat_batch * 2
        boxes = torch.cat([boxes] * 2)
        text_embeddings = torch.cat([text_embeddings] * 2)
        masks = torch.cat([masks] * 2)
        # The first half of the batch is the unconditional branch: zero its grounding masks.
        masks[: repeat_batch // 2] = 0

    # Shallow-copy so that adding the "gligen" entry does not mutate the caller's dict.
    cross_attention_kwargs = {} if cross_attention_kwargs is None else dict(cross_attention_kwargs)
    cross_attention_kwargs["gligen"] = {"boxes": boxes, "positive_embeddings": text_embeddings, "masks": masks}

    # Prepare latent variables for GLIGEN inpainting
    if gligen_inpaint_image is not None:
        # if the given input image is not of the same size as expected by VAE
        # center crop and resize the input image to expected shape
        if gligen_inpaint_image.size != (self.vae.sample_size, self.vae.sample_size):
            gligen_inpaint_image = self.target_size_center_crop(gligen_inpaint_image, self.vae.sample_size)
        # Convert a single image into a batch of images with a batch size of 1
        # The resulting shape becomes (1, C, H, W), where C is the number of channels,
        # and H and W are the height and width of the image.
        # scales the pixel values to a range [-1, 1]
        gligen_inpaint_image = self.image_processor.preprocess(gligen_inpaint_image)
        gligen_inpaint_image = gligen_inpaint_image.to(dtype=self.vae.dtype, device=self.vae.device)
        # Run AutoEncoder to get corresponding latents
        gligen_inpaint_latent = self.vae.encode(gligen_inpaint_image).latent_dist.sample()
        gligen_inpaint_latent = self.vae.config.scaling_factor * gligen_inpaint_latent
        # Generate an inpainting mask
        # pixel value = 0, where the object is present (defined by bounding boxes above)
        #               1, everywhere else
        gligen_inpaint_mask = self.draw_inpaint_mask_from_boxes(gligen_boxes, gligen_inpaint_latent.shape[2:])
        gligen_inpaint_mask = gligen_inpaint_mask.to(
            dtype=gligen_inpaint_latent.dtype, device=gligen_inpaint_latent.device
        )
        gligen_inpaint_mask = gligen_inpaint_mask[None, None]
        gligen_inpaint_mask_addition = torch.cat(
            (gligen_inpaint_latent * gligen_inpaint_mask, gligen_inpaint_mask), dim=1
        )
        # Convert a single mask into a batch of masks with a batch size of 1
        gligen_inpaint_mask_addition = gligen_inpaint_mask_addition.expand(repeat_batch, -1, -1, -1).clone()

    # Grounding (the fuser) stays active for the first `num_grounding_steps` iterations of the loop.
    num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps))
    self.enable_fuser(True)

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Scheduled sampling
            if i == num_grounding_steps:
                self.enable_fuser(False)

            if latents.shape[1] != 4:
                latents = torch.randn_like(latents[:, :4])

            if gligen_inpaint_image is not None:
                gligen_inpaint_latent_with_noise = (
                    self.scheduler.add_noise(
                        gligen_inpaint_latent, torch.randn_like(gligen_inpaint_latent), torch.tensor([t])
                    )
                    .expand(latents.shape[0], -1, -1, -1)
                    .clone()
                )
                # Keep the noised reference latents where the mask is 1 and the
                # current latents where it is 0 (inside the boxes).
                latents = gligen_inpaint_latent_with_noise * gligen_inpaint_mask + latents * (
                    1 - gligen_inpaint_mask
                )

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            if gligen_inpaint_image is not None:
                latent_model_input = torch.cat((latent_model_input, gligen_inpaint_mask_addition), dim=1)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import PIL.Image import torch from transformers import ( CLIPImageProcessor, CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ...image_processor import VaeImageProcessor from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention import GatedSelfAttentionDense from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.clip_image_project_model import CLIPImageProjection from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionGLIGENTextImagePipeline >>> from diffusers.utils import load_image >>> # Insert objects described by image at the region defined by bounding boxes >>> pipe = StableDiffusionGLIGENTextImagePipeline.from_pretrained( ... "anhnct/Gligen_Inpainting_Text_Image", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> input_image = load_image( ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/livingroom_modern.png" ... ) >>> prompt = "a backpack" >>> boxes = [[0.2676, 0.4088, 0.4773, 0.7183]] >>> phrases = None >>> gligen_image = load_image( ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/backpack.jpeg" ... ) >>> images = pipe( ... prompt=prompt, ... gligen_phrases=phrases, ... 
gligen_inpaint_image=input_image, ... gligen_boxes=boxes, ... gligen_images=[gligen_image], ... gligen_scheduled_sampling_beta=1, ... output_type="pil", ... num_inference_steps=50, ... ).images >>> images[0].save("./gligen-inpainting-text-image-box.jpg") >>> # Generate an image described by the prompt and >>> # insert objects described by text and image at the region defined by bounding boxes >>> pipe = StableDiffusionGLIGENTextImagePipeline.from_pretrained( ... "anhnct/Gligen_Text_Image", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> prompt = "a flower sitting on the beach" >>> boxes = [[0.0, 0.09, 0.53, 0.76]] >>> phrases = ["flower"] >>> gligen_image = load_image( ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/pexels-pixabay-60597.jpg" ... ) >>> images = pipe( ... prompt=prompt, ... gligen_phrases=phrases, ... gligen_images=[gligen_image], ... gligen_boxes=boxes, ... gligen_scheduled_sampling_beta=1, ... output_type="pil", ... num_inference_steps=50, ... ).images >>> images[0].save("./gligen-generation-text-image-box.jpg") >>> # Generate an image described by the prompt and >>> # transfer style described by image at the region defined by bounding boxes >>> pipe = StableDiffusionGLIGENTextImagePipeline.from_pretrained( ... "anhnct/Gligen_Text_Image", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> prompt = "a dragon flying on the sky" >>> boxes = [[0.4, 0.2, 1.0, 0.8], [0.0, 1.0, 0.0, 1.0]] # Set `[0.0, 1.0, 0.0, 1.0]` for the style >>> gligen_image = load_image( ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/landscape.png" ... ) >>> gligen_placeholder = load_image( ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/landscape.png" ... ) >>> images = pipe( ... prompt=prompt, ... gligen_phrases=[ ... "dragon", ... "placeholder", ... 
], # Can use any text instead of `placeholder` token, because we will use mask here ... gligen_images=[ ... gligen_placeholder, ... gligen_image, ... ], # Can use any image in gligen_placeholder, because we will use mask here ... input_phrases_mask=[1, 0], # Set 0 for the placeholder token ... input_images_mask=[0, 1], # Set 0 for the placeholder image ... gligen_boxes=boxes, ... gligen_scheduled_sampling_beta=1, ... output_type="pil", ... num_inference_steps=50, ... ).images >>> images[0].save("./gligen-generation-text-image-box-style-transfer.jpg") ``` """ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionMixin): r""" Pipeline for text-to-image generation using Stable Diffusion with Grounded-Language-to-Image Generation (GLIGEN). This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.). Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. processor ([`~transformers.CLIPProcessor`]): A `CLIPProcessor` to procces reference image. image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): Frozen image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). image_project ([`CLIPImageProjection`]): A `CLIPImageProjection` to project image embedding into phrases embedding space. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    processor: CLIPProcessor,
    image_encoder: CLIPVisionModelWithProjection,
    image_project: CLIPImageProjection,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and derive the VAE scale factor and image processor.

    Logs a warning when the safety checker is disabled while `requires_safety_checker` is `True`.

    Raises:
        ValueError: If a `safety_checker` is supplied without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string, otherwise `{self.__class__}`
        # is printed literally instead of the pipeline class.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        image_encoder=image_encoder,
        processor=processor,
        image_project=image_project,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # One factor of 2 per VAE block beyond the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 
attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into # the tuple to access the hidden states from the desired layer. prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if self.text_encoder is not None: if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if 
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Scheduler `step` signatures differ, so `eta` and `generator` are forwarded only when the
    current scheduler's `step` declares a parameter with that name.

    Args:
        generator: Value forwarded as `generator` when supported.
        eta: Value forwarded as `eta` when supported (the DDIM paper's eta,
            https://arxiv.org/abs/2010.02502, expected to lie in [0, 1]).

    Returns:
        `dict` containing at most the keys `"eta"` and `"generator"`, in that order.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_kwargs = {"eta": eta, "generator": generator}
    return {key: value for key, value in optional_kwargs.items() if key in step_parameters}
type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." 
def draw_inpaint_mask_from_boxes(self, boxes, size):
    """
    Create an inpainting mask from normalized bounding boxes.

    Args:
        boxes: Iterable of `[xmin, ymin, xmax, ymax]` boxes with values in [0, 1].
        size: `(height, width)` of the mask to create (rows first, columns second).

    Returns:
        `torch.Tensor` of shape `(height, width)` that is 0 inside every box and 1 elsewhere.
    """
    height, width = size[0], size[1]
    inpaint_mask = torch.ones(height, width)
    for box in boxes:
        # x coordinates span the columns (width) and y coordinates span the rows (height);
        # scaling them by the opposite axis misplaces boxes on non-square masks.
        x0, x1 = box[0] * width, box[2] * width
        y0, y1 = box[1] * height, box[3] * height
        inpaint_mask[int(y0) : int(y1), int(x0) : int(x1)] = 0
    return inpaint_mask
def complete_mask(self, has_mask, max_objs, device):
    """
    Expand a user-supplied phrase/image mask to a `(1, max_objs)` tensor.

    Args:
        has_mask: `None` (keep every slot), a single `int` applied to every slot, or an iterable
            of per-object values written into the leading slots.
        max_objs: Number of object slots in the returned mask.
        device: Device the mask is moved to.

    Returns:
        `torch.Tensor` of shape `(1, max_objs)` in `self.text_encoder.dtype`. Slots not covered by
        `has_mask` keep the value 1.
    """
    mask = torch.ones(1, max_objs).type(self.text_encoder.dtype).to(device)
    if has_mask is None:
        return mask

    if isinstance(has_mask, int):
        return mask * has_mask

    # Only `max_objs` slots exist; pairing with `range(max_objs)` drops surplus entries instead of
    # raising IndexError when the caller passes more mask values than slots.
    for idx, value in zip(range(max_objs), has_mask):
        mask[0, idx] = value
    return mask
""" phrases, images = gligen_phrases, gligen_images images = [None] * len(phrases) if images is None else images phrases = [None] * len(images) if phrases is None else phrases boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype) masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) phrases_masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) image_masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) phrases_embeddings = torch.zeros(max_objs, hidden_size, device=device, dtype=self.text_encoder.dtype) image_embeddings = torch.zeros(max_objs, hidden_size, device=device, dtype=self.text_encoder.dtype) text_features = [] image_features = [] for phrase, image in zip(phrases, images): text_features.append(self.get_clip_feature(phrase, normalize_constant, device, is_image=False)) image_features.append(self.get_clip_feature(image, normalize_constant, device, is_image=True)) for idx, (box, text_feature, image_feature) in enumerate(zip(gligen_boxes, text_features, image_features)): boxes[idx] = torch.tensor(box) masks[idx] = 1 if text_feature is not None: phrases_embeddings[idx] = text_feature phrases_masks[idx] = 1 if image_feature is not None: image_embeddings[idx] = image_feature image_masks[idx] = 1 input_phrases_mask = self.complete_mask(input_phrases_mask, max_objs, device) phrases_masks = phrases_masks.unsqueeze(0).repeat(repeat_batch, 1) * input_phrases_mask input_images_mask = self.complete_mask(input_images_mask, max_objs, device) image_masks = image_masks.unsqueeze(0).repeat(repeat_batch, 1) * input_images_mask boxes = boxes.unsqueeze(0).repeat(repeat_batch, 1, 1) masks = masks.unsqueeze(0).repeat(repeat_batch, 1) phrases_embeddings = phrases_embeddings.unsqueeze(0).repeat(repeat_batch, 1, 1) image_embeddings = image_embeddings.unsqueeze(0).repeat(repeat_batch, 1, 1) out = { "boxes": boxes, "masks": masks, "phrases_masks": phrases_masks, "image_masks": image_masks, 
def get_cross_attention_kwargs_without_grounded(self, hidden_size, repeat_batch, max_objs, device):
    """
    Build the GLIGEN cross-attention kwargs for a pass without grounding information.

    Every entry (boxes, masks, phrase/image masks, phrase/image embeddings) is an all-zero tensor
    in `self.text_encoder.dtype` on `device`, with a leading batch dimension of `repeat_batch`.

    Returns:
        `dict` with keys `"boxes"` `(repeat_batch, max_objs, 4)`, `"masks"`, `"phrases_masks"`,
        `"image_masks"` `(repeat_batch, max_objs)`, and `"phrases_embeddings"`, `"image_embeddings"`
        `(repeat_batch, max_objs, hidden_size)`.
    """
    dtype = self.text_encoder.dtype

    def batched_zeros(*per_sample_shape):
        # Allocate one zero sample, then tile it along a new leading batch axis.
        sample = torch.zeros(*per_sample_shape, device=device, dtype=dtype)
        tiling = [repeat_batch] + [1] * len(per_sample_shape)
        return sample.unsqueeze(0).repeat(*tiling)

    return {
        "boxes": batched_zeros(max_objs, 4),
        "masks": batched_zeros(max_objs),
        "phrases_masks": batched_zeros(max_objs),
        "image_masks": batched_zeros(max_objs),
        "phrases_embeddings": batched_zeros(max_objs, hidden_size),
        "image_embeddings": batched_zeros(max_objs, hidden_size),
    }
Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, gligen_normalize_constant: float = 28.7, clip_skip: int = None, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. gligen_phrases (`List[str]`): The phrases to guide what to include in each of the regions defined by the corresponding `gligen_boxes`. There should only be one phrase per bounding box. gligen_images (`List[PIL.Image.Image]`): The images to guide what to include in each of the regions defined by the corresponding `gligen_boxes`. 
There should only be one image per bounding box input_phrases_mask (`int` or `List[int]`): pre phrases mask input defined by the correspongding `input_phrases_mask` input_images_mask (`int` or `List[int]`): pre images mask input defined by the correspongding `input_images_mask` gligen_boxes (`List[List[float]]`): The bounding boxes that identify rectangular regions of the image that are going to be filled with the content described by the corresponding `gligen_phrases`. Each rectangular box is defined as a `List[float]` of 4 elements `[xmin, ymin, xmax, ymax]` where each value is between [0,1]. gligen_inpaint_image (`PIL.Image.Image`, *optional*): The input image, if provided, is inpainted with objects described by the `gligen_boxes` and `gligen_phrases`. Otherwise, it is treated as a generation task on a blank input image. gligen_scheduled_sampling_beta (`float`, defaults to 0.3): Scheduled Sampling factor from [GLIGEN: Open-Set Grounded Text-to-Image Generation](https://arxiv.org/pdf/2301.07093.pdf). Scheduled Sampling factor is only varied for scheduled sampling during inference for improved quality and controllability. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. 
latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). gligen_normalize_constant (`float`, *optional*, defaults to 28.7): The normalize value of the image embedding. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clip_skip=clip_skip, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. 
Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 5.1 Prepare GLIGEN variables max_objs = 30 if len(gligen_boxes) > max_objs: warnings.warn( f"More that {max_objs} objects found. Only first {max_objs} objects will be processed.", FutureWarning, ) gligen_phrases = gligen_phrases[:max_objs] gligen_boxes = gligen_boxes[:max_objs] gligen_images = gligen_images[:max_objs] repeat_batch = batch_size * num_images_per_prompt if do_classifier_free_guidance: repeat_batch = repeat_batch * 2 if cross_attention_kwargs is None: cross_attention_kwargs = {} hidden_size = prompt_embeds.shape[2] cross_attention_kwargs["gligen"] = self.get_cross_attention_kwargs_with_grounded( hidden_size=hidden_size, gligen_phrases=gligen_phrases, gligen_images=gligen_images, gligen_boxes=gligen_boxes, input_phrases_mask=input_phrases_mask, input_images_mask=input_images_mask, repeat_batch=repeat_batch, normalize_constant=gligen_normalize_constant, max_objs=max_objs, device=device, ) cross_attention_kwargs_without_grounded = {} cross_attention_kwargs_without_grounded["gligen"] = self.get_cross_attention_kwargs_without_grounded( hidden_size=hidden_size, repeat_batch=repeat_batch, max_objs=max_objs, device=device ) # Prepare latent variables for GLIGEN inpainting if gligen_inpaint_image is not None: # if the given input image is not of the same size as expected by VAE # center crop and resize the input image to expected shape if gligen_inpaint_image.size != (self.vae.sample_size, self.vae.sample_size): gligen_inpaint_image = self.target_size_center_crop(gligen_inpaint_image, self.vae.sample_size) # Convert a single image into a batch of images with a batch size of 1 # The resulting shape becomes (1, C, H, W), where C is the number of channels, # and H and W are the height and width of the image. 
# scales the pixel values to a range [-1, 1] gligen_inpaint_image = self.image_processor.preprocess(gligen_inpaint_image) gligen_inpaint_image = gligen_inpaint_image.to(dtype=self.vae.dtype, device=self.vae.device) # Run AutoEncoder to get corresponding latents gligen_inpaint_latent = self.vae.encode(gligen_inpaint_image).latent_dist.sample() gligen_inpaint_latent = self.vae.config.scaling_factor * gligen_inpaint_latent # Generate an inpainting mask # pixel value = 0, where the object is present (defined by bounding boxes above) # 1, everywhere else gligen_inpaint_mask = self.draw_inpaint_mask_from_boxes(gligen_boxes, gligen_inpaint_latent.shape[2:]) gligen_inpaint_mask = gligen_inpaint_mask.to( dtype=gligen_inpaint_latent.dtype, device=gligen_inpaint_latent.device ) gligen_inpaint_mask = gligen_inpaint_mask[None, None] gligen_inpaint_mask_addition = torch.cat( (gligen_inpaint_latent * gligen_inpaint_mask, gligen_inpaint_mask), dim=1 ) # Convert a single mask into a batch of masks with a batch size of 1 gligen_inpaint_mask_addition = gligen_inpaint_mask_addition.expand(repeat_batch, -1, -1, -1).clone() int(gligen_scheduled_sampling_beta * len(timesteps)) self.enable_fuser(True) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if latents.shape[1] != 4: latents = torch.randn_like(latents[:, :4]) if gligen_inpaint_image is not None: gligen_inpaint_latent_with_noise = ( self.scheduler.add_noise( gligen_inpaint_latent, torch.randn_like(gligen_inpaint_latent), torch.tensor([t]) ) .expand(latents.shape[0], -1, -1, -1) .clone() ) latents = gligen_inpaint_latent_with_noise * gligen_inpaint_mask + latents * ( 1 - gligen_inpaint_mask ) # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) if gligen_inpaint_image is not None: latent_model_input = torch.cat((latent_model_input, gligen_inpaint_mask_addition), dim=1) # predict the noise residual with grounded information noise_pred_with_grounding = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample # predict the noise residual without grounded information noise_pred_without_grounding = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs_without_grounded, ).sample # perform guidance if do_classifier_free_guidance: # Using noise_pred_text from noise residual with grounded information and noise_pred_uncond from noise residual without grounded information _, noise_pred_text = noise_pred_with_grounding.chunk(2) noise_pred_uncond, _ = noise_pred_without_grounding.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) else: noise_pred = noise_pred_with_grounding # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, 
if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: src/diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_k_diffusion_available, is_k_diffusion_version, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if not ( is_transformers_available() and is_torch_available() and is_k_diffusion_available() and is_k_diffusion_version(">=", "0.0.12") ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) else: _import_structure["pipeline_stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"] _import_structure["pipeline_stable_diffusion_xl_k_diffusion"] = 
["StableDiffusionXLKDiffusionPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not ( is_transformers_available() and is_torch_available() and is_k_diffusion_available() and is_k_diffusion_version(">=", "0.0.12") ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * else: from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline from .pipeline_stable_diffusion_xl_k_diffusion import StableDiffusionXLKDiffusionPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class ModelWrapper:
    """Expose a diffusers UNet through an `apply_model(x, t, cond)` style interface.

    The pipeline wraps its UNet in this class before handing it to k-diffusion's `CompVisDenoiser` /
    `CompVisVDenoiser`. The wrapper translates the conditioning argument (third positional argument or `cond`
    keyword) into the UNet's `encoder_hidden_states` keyword and returns the `.sample` attribute of the UNet output.

    Args:
        model: Callable UNet; it is invoked as `model(*args, encoder_hidden_states=..., **kwargs)` and its result
            must expose a `.sample` attribute.
        alphas_cumprod: Stored unchanged on the wrapper as `alphas_cumprod`.
    """

    def __init__(self, model, alphas_cumprod):
        self.model = model
        self.alphas_cumprod = alphas_cumprod

    def apply_model(self, *args, **kwargs):
        """Run the wrapped model and return the `.sample` of its output.

        The conditioning may be given either as a third positional argument or as the keyword `cond`; a `cond`
        that is not `None` takes precedence over the positional value.

        Raises:
            TypeError: If neither a third positional argument nor a non-`None` `cond` keyword is supplied.
        """
        # Always strip `cond` from the keyword arguments: the wrapped model is called with
        # `encoder_hidden_states` instead, so an explicit `cond=None` must not be forwarded to it.
        cond = kwargs.pop("cond", None)

        has_positional_cond = len(args) == 3
        if has_positional_cond:
            encoder_hidden_states = args[-1]
            args = args[:2]

        if cond is not None:
            encoder_hidden_states = cond
        elif not has_positional_cond:
            # Previously this fell through to an UnboundLocalError on `encoder_hidden_states`.
            raise TypeError(
                "`apply_model` requires the conditioning either as a third positional argument or as `cond=...`."
            )

        return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
class StableDiffusionKDiffusionPipeline(
    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion and a sampler from the `k_diffusion` package.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights

    This is an experimental pipeline and is likely to change in the future.

    A sampler must be selected with [`~StableDiffusionKDiffusionPipeline.set_scheduler`] before the pipeline is
    called.

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. Only its config is used: the
            constructor rebuilds it as an [`LMSDiscreteScheduler`] via `from_config`.
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    model_cpu_offload_seq = "text_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor"]
    _exclude_from_cpu_offload = ["safety_checker"]

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker,
        feature_extractor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        logger.info(
            f"{self.__class__} is an experimental pipeline and is likely to change in the future. We recommend to use"
            " this pipeline for fast experimentation / iteration if needed, but advise to rely on existing pipelines"
            " as defined in https://huggingface.co/docs/diffusers/api/schedulers#implemented-schedulers for"
            " production settings."
        )

        # get correct sigmas from LMS
        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        # Wrap the UNet so it can be driven by the k-diffusion denoiser matching the prediction type.
        model = ModelWrapper(unet, scheduler.alphas_cumprod)
        if scheduler.config.prediction_type == "v_prediction":
            self.k_diffusion_model = CompVisVDenoiser(model)
        else:
            self.k_diffusion_model = CompVisDenoiser(model)

    def set_scheduler(self, scheduler_type: str):
        """
        Select the k-diffusion sampling function used by `__call__`.

        Args:
            scheduler_type (`str`):
                Name of an attribute of `k_diffusion.sampling`, e.g. `"sample_dpmpp_2m"`. The resolved object is
                stored as `self.sampler`.

        Raises:
            ValueError: If `scheduler_type` cannot be looked up on `k_diffusion.sampling`. The message lists the
                names in that module containing `"sample_"`.
        """
        library = importlib.import_module("k_diffusion")
        sampling = library.sampling
        try:
            self.sampler = getattr(sampling, scheduler_type)
        except (AttributeError, TypeError) as err:
            # AttributeError: unknown name; TypeError: `scheduler_type` is not a string.
            valid_samplers = [s for s in dir(sampling) if "sample_" in s]

            raise ValueError(
                f"Invalid scheduler type {scheduler_type}. Please choose one of {valid_samplers}."
            ) from err

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        **kwargs,
    ):
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )

        # concatenate for backwards comp
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """
        Validate the arguments of `__call__` and raise `ValueError` on the first inconsistency found.

        Checks, in order: `height`/`width` divisible by 8; `callback_steps` a positive `int` (when not `None`);
        `callback_on_step_end_tensor_inputs` a subset of `self._callback_tensor_inputs` (when not `None`); exactly
        one of `prompt` / `prompt_embeds` given, with `prompt` a `str` or `list`; not both `negative_prompt` and
        `negative_prompt_embeds` given; and matching shapes when both embedding tensors are passed.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
        # NOTE(review): `_callback_tensor_inputs` is not defined in this class; presumably it has to come from a
        # base class if this argument is ever passed -- confirm.
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """
        Return the initial latents of shape
        `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)`.

        When `latents` is `None`, new latents are drawn with `randn_tensor`; otherwise the given tensor is checked
        against the expected shape and moved to `device`.

        Raises:
            ValueError: If `latents` is given with a shape different from the expected one.
        """
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # The latents are returned unscaled; `__call__` multiplies them by the first sigma.
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        use_karras_sigmas: Optional[bool] = False,
        noise_sampler_seed: Optional[int] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        A sampler has to be selected with `set_scheduler` before the first call.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Higher guidance scale encourages to generate images
                that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
                pipeline always runs with classifier-free guidance, so the value must be greater than `1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Accepted for
                signature compatibility; this implementation does not read it.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic. It is also forwarded to the sampler when the sampler accepts a
                `generator` argument.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
                VAE decoding and the safety checker are skipped.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                Accepted for signature compatibility. This implementation does not pass it to the k-diffusion
                sampler, so it is never invoked.
            callback_steps (`int`, *optional*, defaults to 1):
                Only validated (must be a positive integer); see `callback`.
            use_karras_sigmas (`bool`, *optional*, defaults to `False`):
                Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to
                `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M
                Karras`.
            noise_sampler_seed (`int`, *optional*, defaults to `None`):
                The random seed to use for the noise sampler. If `None`, a random seed will be generated.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.

        Raises:
            ValueError: If the inputs are inconsistent (see `check_inputs`), if no sampler has been selected with
                `set_scheduler`, or if `guidance_scale <= 1`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # `self.sampler` only exists after `set_scheduler` has been called; fail early with an actionable message
        # instead of an AttributeError after the prompt has already been encoded.
        sampler = getattr(self, "sampler", None)
        if sampler is None:
            raise ValueError(
                "No k-diffusion sampler has been selected. Call `set_scheduler(...)` (e.g."
                ' `pipe.set_scheduler("sample_dpmpp_2m")`) before running the pipeline.'
            )
        sampler_parameters = inspect.signature(sampler).parameters

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = True
        if guidance_scale <= 1.0:
            raise ValueError(
                "This pipeline always uses classifier-free guidance, so `guidance_scale` has to be greater than 1.0"
                f" but is {guidance_scale}."
            )

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            clip_skip=clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device)

        # 5. Prepare sigmas
        if use_karras_sigmas:
            sigma_min: float = self.k_diffusion_model.sigmas[0].item()
            sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
            sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
            sigmas = sigmas.to(device)
        else:
            sigmas = self.scheduler.sigmas
        sigmas = sigmas.to(prompt_embeds.dtype)

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        # Scale the initial noise by the first (largest) sigma of the schedule.
        latents = latents * sigmas[0]
        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)

        # 7. Define model function
        def model_fn(x, t):
            # Run the unconditional and text-conditioned branches in one batch, then combine them.
            latent_model_input = torch.cat([x] * 2)
            t = torch.cat([t] * 2)

            noise_pred = self.k_diffusion_model(latent_model_input, t, cond=prompt_embeds)

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            return noise_pred

        # 8. Run k-diffusion solver
        # Only pass the optional arguments that the selected sampler actually declares.
        sampler_kwargs = {}

        if "noise_sampler" in sampler_parameters:
            min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
            noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
            sampler_kwargs["noise_sampler"] = noise_sampler

        if "generator" in sampler_parameters:
            sampler_kwargs["generator"] = generator

        latents = sampler(model_fn, latents, sigmas, **sampler_kwargs)

        if output_type != "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
import importlib import inspect from typing import List, Optional, Tuple, Union import torch from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras from transformers import ( CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, ) from ...image_processor import VaeImageProcessor from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionXLKDiffusionPipeline >>> pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained( ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 ... 
class ModelWrapper:
    """Adapter that exposes a UNet-style model through an ``apply_model`` method.

    The wrapped object is handed to the k-diffusion ``CompVisDenoiser`` / ``CompVisVDenoiser`` classes, which
    presumably invoke ``apply_model`` and read ``alphas_cumprod`` (confirm against the k-diffusion sources).

    Args:
        model: Callable accepting ``(*args, encoder_hidden_states=..., **kwargs)`` and returning an object with a
            ``sample`` attribute.
        alphas_cumprod: Cumulative alpha products of the noise schedule; stored unchanged.
    """

    def __init__(self, model, alphas_cumprod):
        self.model = model
        self.alphas_cumprod = alphas_cumprod

    def apply_model(self, *args, **kwargs):
        """Run the wrapped model and return the ``sample`` attribute of its output.

        The text conditioning may be supplied either as a third positional argument or as a non-``None`` ``cond``
        keyword argument; when both are given, ``cond`` takes precedence. All remaining positional and keyword
        arguments are forwarded to the model untouched.

        Raises:
            ValueError: If no conditioning is supplied in either form.
        """
        # Sentinel instead of `None` so an explicitly positional `None` conditioning still counts as "supplied",
        # exactly as it did before.
        missing = object()
        encoder_hidden_states = missing

        if len(args) == 3:
            encoder_hidden_states = args[-1]
            args = args[:2]
        if kwargs.get("cond", None) is not None:
            encoder_hidden_states = kwargs.pop("cond")

        if encoder_hidden_states is missing:
            # Previously this path crashed with an UnboundLocalError that gave no hint about the cause.
            raise ValueError(
                "`apply_model` needs the conditioning either as a third positional argument or as a non-None"
                " `cond` keyword argument."
            )

        return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
Stable Diffusion XL uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([` CLIPTextModelWithProjection`]): Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1-0`. 
def set_scheduler(self, scheduler_type: str):
    """Select the k-diffusion sampler function used by the pipeline.

    Looks up ``scheduler_type`` as an attribute of ``k_diffusion.sampling`` and stores the result on
    ``self.sampler``.

    Args:
        scheduler_type (`str`): Name of an attribute of ``k_diffusion.sampling``, e.g. ``"sample_dpmpp_2m_sde"``.

    Raises:
        ValueError: If ``k_diffusion.sampling`` has no attribute with that name. The message lists every
            attribute whose name contains ``"sample_"`` and the original lookup error is chained as the cause.
    """
    library = importlib.import_module("k_diffusion")
    sampling = getattr(library, "sampling")
    try:
        sampler = getattr(sampling, scheduler_type)
    except (AttributeError, TypeError) as err:
        # Only a failed lookup (unknown name, or a non-string name) means "invalid scheduler type"; anything
        # else is a genuine bug and must not be disguised as a user error.
        valid_samplers = [name for name in dir(sampling) if "sample_" in name]
        raise ValueError(
            f"Invalid scheduler type {scheduler_type}. Please choose one of {valid_samplers}."
        ) from err
    self.sampler = sampler
negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
""" device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was 
truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = 
def check_inputs(
    self,
    prompt,
    prompt_2,
    height,
    width,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
):
    """Validate the user-facing call arguments and raise `ValueError` on the first inconsistency.

    Checks, in order: image size divisible by 8; prompts and pre-computed embeddings are mutually exclusive and
    at least one is given; prompts are `str` or `list`; negative prompts and negative embeddings are mutually
    exclusive; positive and negative embeddings share a shape; every embedding comes with its pooled counterpart.
    Returns `None` when everything is consistent.
    """
    size_ok = height % 8 == 0 and width % 8 == 0
    if not size_ok:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    has_prompt = prompt is not None
    has_prompt_2 = prompt_2 is not None
    has_embeds = prompt_embeds is not None

    # Every branch below raises, so a flat sequence of guards is equivalent to an if/elif chain.
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if has_prompt_2 and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    if has_prompt_2 and not isinstance(prompt_2, (str, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    has_negative_embeds = negative_prompt_embeds is not None

    if has_negative_embeds and negative_prompt is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    if has_negative_embeds and negative_prompt_2 is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    if has_embeds and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if has_negative_embeds and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )
def upcast_vae(self):
    """Cast the VAE to float32, restoring the previous dtype on selected sub-modules when that is safe.

    When the first attention layer of the decoder's mid block uses one of `AttnProcessor2_0`,
    `XFormersAttnProcessor` or `FusedAttnProcessor2_0`, the `post_quant_conv`, decoder `conv_in` and decoder
    `mid_block` modules are cast back to the dtype the VAE had on entry.
    """
    previous_dtype = self.vae.dtype
    self.vae.to(dtype=torch.float32)

    mid_block_processor = self.vae.decoder.mid_block.attentions[0].processor
    memory_efficient_attention = isinstance(
        mid_block_processor,
        (AttnProcessor2_0, XFormersAttnProcessor, FusedAttnProcessor2_0),
    )
    if not memory_efficient_attention:
        return

    # if xformers or torch_2_0 is used attention block does not need
    # to be in float32 which can save lots of memory
    self.vae.post_quant_conv.to(previous_dtype)
    self.vae.decoder.conv_in.to(previous_dtype)
    self.vae.decoder.mid_block.to(previous_dtype)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    original_size: Optional[Tuple[int, int]] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    target_size: Optional[Tuple[int, int]] = None,
    negative_original_size: Optional[Tuple[int, int]] = None,
    negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
    negative_target_size: Optional[Tuple[int, int]] = None,
    use_karras_sigmas: Optional[bool] = False,
    noise_sampler_seed: Optional[int] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image. This is set to 1024 by default for the best results.
            Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image. This is set to 1024 by default for the best results.
            Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Higher guidance scale encourages to generate images that
            are closely linked to the text `prompt`, usually at the expense of lower image quality. This pipeline
            requires `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
            sampled latents are returned without VAE decoding or post-processing.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
            of a plain tuple.
        original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
            `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
            explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
            `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
            `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            For most cases, `target_size` should be set to the desired height and width of the generated image. If
            not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
            section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            To negatively condition the generation process based on a specific image resolution. Only used when
            `negative_target_size` is passed as well; otherwise the positive conditioning is reused. Part of
            SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
            micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            To negatively condition the generation process based on a target image resolution. It should be as same
            as the `target_size` for most cases. Only used when `negative_original_size` is passed as well. Part
            of SDXL's micro-conditioning as explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
            information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            If `True`, the noise levels are computed with `k_diffusion.sampling.get_sigmas_karras` between the
            first and last entries of `self.k_diffusion_model.sigmas`; otherwise `self.scheduler.sigmas` is used.
        noise_sampler_seed (`int`, *optional*):
            Seed handed to `BrownianTreeNoiseSampler`. Only used when the selected sampler accepts a
            `noise_sampler` argument.
        clip_skip (`int`, *optional*):
            Forwarded to `encode_prompt`: number of layers to be skipped from CLIP while computing the prompt
            embeddings.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images.

    Raises:
        ValueError: If the inputs fail `check_inputs`, or if `guidance_scale <= 1`.
    """

    # 0. Default height and width to unet
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    original_size = original_size or (height, width)
    target_size = target_size or (height, width)

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        negative_prompt,
        negative_prompt_2,
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    )

    # `model_fn` below always runs an unconditional and a conditional pass, so guidance is mandatory.
    if guidance_scale <= 1.0:
        raise ValueError("has to use guidance_scale")

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. Encode input prompt
    lora_scale = None

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        lora_scale=lora_scale,
        clip_skip=self.clip_skip,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device)

    # 5. Prepare sigmas
    if use_karras_sigmas:
        sigma_min: float = self.k_diffusion_model.sigmas[0].item()
        sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
        sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
    else:
        sigmas = self.scheduler.sigmas
    sigmas = sigmas.to(dtype=prompt_embeds.dtype, device=device)

    # 6. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    # Scale the unit-variance noise to the first (largest) noise level of the schedule.
    latents = latents * sigmas[0]

    self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
    self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)

    # 7. Prepare added time ids & embeddings
    add_text_embeds = pooled_prompt_embeds
    if self.text_encoder_2 is None:
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
    else:
        text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

    add_time_ids = self._get_add_time_ids(
        original_size,
        crops_coords_top_left,
        target_size,
        dtype=prompt_embeds.dtype,
        text_encoder_projection_dim=text_encoder_projection_dim,
    )
    if negative_original_size is not None and negative_target_size is not None:
        negative_add_time_ids = self._get_add_time_ids(
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
    else:
        negative_add_time_ids = add_time_ids

    # Tile each half to the effective batch size BEFORE concatenating, so the rows are laid out as
    # [negative x B, positive x B] exactly like `prompt_embeds` / `add_text_embeds`. Concatenating first and
    # tiling afterwards would interleave them as [neg, pos, neg, pos, ...] and pair the micro-conditioning with
    # the wrong half of the batch whenever the negative ids differ from the positive ones.
    effective_batch_size = batch_size * num_images_per_prompt
    tiled_negative_add_time_ids = negative_add_time_ids.repeat(effective_batch_size, 1)
    add_time_ids = add_time_ids.repeat(effective_batch_size, 1)

    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([tiled_negative_add_time_ids, add_time_ids], dim=0)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    add_time_ids = add_time_ids.to(device)

    added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

    # 8. Optionally get Guidance Scale Embedding
    # NOTE(review): `get_guidance_scale_embedding` is not defined in the visible part of this class - confirm a
    # base class provides it. Also, when `time_cond_proj_dim` is set, `do_classifier_free_guidance` is False and
    # the embeddings above are not doubled while `model_fn` still doubles the latents; verify this path.
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    # 9. Define model function
    def model_fn(x, t):
        # One batched forward pass covers both the unconditional and the text-conditioned prediction.
        latent_model_input = torch.cat([x] * 2)
        t = torch.cat([t] * 2)

        noise_pred = self.k_diffusion_model(
            latent_model_input,
            t,
            cond=prompt_embeds,
            timestep_cond=timestep_cond,
            added_cond_kwargs=added_cond_kwargs,
        )

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        return noise_pred

    # 10. Run k-diffusion solver
    # Only pass the optional arguments the selected sampler function actually accepts.
    sampler_kwargs = {}
    sampler_parameters = inspect.signature(self.sampler).parameters

    if "noise_sampler" in sampler_parameters:
        min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
        noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
        sampler_kwargs["noise_sampler"] = noise_sampler

    if "generator" in sampler_parameters:
        sampler_kwargs["generator"] = generator

    latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)

    if not output_type == "latent":
        # make sure the VAE is in float32 mode, as it overflows in float16
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        if needs_upcasting:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
    else:
        image = latents

    if not output_type == "latent":
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return StableDiffusionXLPipelineOutput(images=image)
# Lazy-import scaffolding for the `stable_diffusion_ldm3d` sub-package.
# `_dummy_objects` collects placeholder objects that are exposed when the optional
# dependencies are missing; `_import_structure` maps a submodule name to the public
# names it provides.
_dummy_objects = {}
_import_structure = {}

# The pipeline needs both `transformers` and `torch`. When either one is missing,
# fall back to the dummy objects instead of registering the real pipeline.
try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"]

# Under static type checking, or when slow imports are explicitly requested, import the
# names eagerly (real pipeline if the dependencies are present, dummies otherwise).
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline

else:
    import sys

    # Replace this module in `sys.modules` with a `_LazyModule` built from
    # `_import_structure`.
    # NOTE(review): presumably `_LazyModule` defers the submodule import until first
    # attribute access -- confirm against its implementation in `...utils`.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach the dummy placeholders (if any) to the module object installed above.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
import inspect from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL.Image import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessorLDM3D from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, BaseOutput, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```python >>> from diffusers import StableDiffusionLDM3DPipeline >>> pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c") >>> pipe = pipe.to("cuda") >>> prompt = "a photo of an astronaut riding a horse on mars" >>> output = pipe(prompt) >>> rgb_image, depth_image = output.rgb, output.depth >>> rgb_image[0].save("astronaut_ldm3d_rgb.jpg") >>> depth_image[0].save("astronaut_ldm3d_depth.png") ``` """ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or the scheduler's
    default spacing driven by `num_inference_steps`. Any extra `kwargs` are forwarded to
    `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read the timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration and the number of inference steps
        (the length of the schedule when a custom schedule was supplied, otherwise `num_inference_steps` as
        passed in).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """

    def _supports(argument_name):
        # Probe the scheduler's signature instead of calling and catching a TypeError.
        return argument_name in inspect.signature(scheduler.set_timesteps).parameters

    has_custom_timesteps = timesteps is not None
    has_custom_sigmas = sigmas is not None

    if has_custom_timesteps and has_custom_sigmas:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler space the steps itself.
    if not (has_custom_timesteps or has_custom_sigmas):
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if has_custom_timesteps:
        if not _supports("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if not _supports("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # With a custom schedule the step count is whatever the scheduler ended up with.
    schedule = scheduler.timesteps
    return schedule, len(schedule)
""" rgb: Union[List[PIL.Image.Image], np.ndarray] depth: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: Optional[List[bool]] class StableDiffusionLDM3DPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, FromSingleFileMixin, ): r""" Pipeline for text-to-image and 3D generation using LDM3D. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    image_encoder: Optional[CLIPVisionModelWithProjection],
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and set up the LDM3D image processor.

    Args:
        vae ([`AutoencoderKL`]): Autoencoder; its `block_out_channels` config determines `vae_scale_factor`.
        text_encoder ([`~transformers.CLIPTextModel`]): Text encoder registered as `text_encoder`.
        tokenizer ([`~transformers.CLIPTokenizer`]): Tokenizer registered as `tokenizer`.
        unet ([`UNet2DConditionModel`]): Denoising model registered as `unet`.
        scheduler ([`KarrasDiffusionSchedulers`]): Scheduler registered as `scheduler`.
        safety_checker ([`StableDiffusionSafetyChecker`]): Optional safety checker; may be `None`.
        feature_extractor ([`~transformers.CLIPImageProcessor`]): Required whenever `safety_checker` is set.
        image_encoder ([`~transformers.CLIPVisionModelWithProjection`], *optional*): Registered as
            `image_encoder`; may be `None`.
        requires_safety_checker (`bool`, defaults to `True`): Stored in the pipeline config. When `True` and
            `safety_checker` is `None`, a warning is logged.

    Raises:
        ValueError: If `safety_checker` is provided without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string, otherwise `{self.__class__}` is printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    # Pixel-to-latent ratio: one factor of two per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`. Both are repeated `num_images_per_prompt` times along
        the batch dimension. When `do_classifier_free_guidance` is false, `negative_prompt_embeds` is returned
        exactly as it was passed in (possibly `None`).

    Raises:
        TypeError: If `negative_prompt` and `prompt` are of different types.
        ValueError: If a list `negative_prompt` does not have `batch_size` entries.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # Batch size comes from the prompt when given, otherwise from the precomputed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize a second time without truncation, only to detect and report dropped text.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # Only pass an attention mask when the text encoder config asks for one.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Pick the dtype to cast the embeddings to: text encoder first, then UNet, else keep as is.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            # No negative prompt: use one empty string per batch entry.
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build the per-adapter image embeddings handed to the UNet.

    Args:
        ip_adapter_image: A single image or a list with one image per IP Adapter. Only used when
            `ip_adapter_image_embeds` is `None`; each entry is encoded with `self.encode_image`.
        ip_adapter_image_embeds: Optional list of precomputed embedding tensors. With classifier-free guidance
            each tensor is split in two along dim 0 into a (negative, positive) pair.
        device: Device the returned tensors are moved to.
        num_images_per_prompt: Number of times each embedding is repeated along dim 0.
        do_classifier_free_guidance: When true, the negative embeddings are concatenated in front of the
            positive ones.

    Returns:
        A list with one tensor per adapter.

    Raises:
        ValueError: If the number of images differs from the number of image projection layers on the UNet.
    """
    positives = []
    negatives = []

    if ip_adapter_image_embeds is not None:
        # Precomputed path: optionally unpack the stacked (negative, positive) halves.
        for packed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_half, packed = packed.chunk(2)
                negatives.append(negative_half)
            positives.append(packed)
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        for adapter_image, proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            # Anything other than a plain ImageProjection is fed hidden states instead of pooled embeds.
            wants_hidden_states = not isinstance(proj_layer, ImageProjection)
            cond, uncond = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positives.append(cond[None, :])
            if do_classifier_free_guidance:
                negatives.append(uncond[None, :])

    prepared = []
    for index, cond in enumerate(positives):
        tiled = torch.cat([cond] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negatives[index]] * num_images_per_prompt, dim=0)
            tiled = torch.cat([tiled_negative, tiled], dim=0)
        prepared.append(tiled.to(device=device))

    return prepared
# diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are only forwarded when the
    scheduler declares a parameter of that name. eta (η) is only used with the DDIMScheduler and is ignored
    by other schedulers; it corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should
    be between [0, 1].

    Args:
        generator: Value forwarded as `generator` when supported.
        eta: Value forwarded as `eta` when supported.

    Returns:
        `dict`: Subset of `{"eta": eta, "generator": generator}` accepted by `self.scheduler.step`.
    """
    accepted = inspect.signature(self.scheduler.step).parameters
    candidates = {"eta": eta, "generator": generator}
    return {name: value for name, value in candidates.items() if name in accepted}
) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
) if ip_adapter_image_embeds is not None: if not isinstance(ip_adapter_image_embeds, list): raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = ( batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding( self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: w (`torch.Tensor`): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type of the generated embeddings. Returns: `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 49,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 49):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
            will be used.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
            provided, embeddings are computed from the `ip_adapter_image` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. Passing
            `"latent"` skips VAE decoding and the safety checker.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a `LDM3DPipelineOutput` instead of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Rescale factor applied to the guided noise prediction via `rescale_noise_cfg` (section 3.4 of
            https://arxiv.org/pdf/2305.08891.pdf). Only used when classifier-free guidance is active and the
            value is greater than 0.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that calls at the end of each denoising steps during the inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        **kwargs:
            Accepts the deprecated `callback` and `callback_steps` arguments.

    Examples:

    Returns:
        `LDM3DPipelineOutput` or `tuple`:
            If `return_dict` is `True`, an `LDM3DPipelineOutput` with fields `rgb`, `depth` and
            `nsfw_content_detected` is returned, otherwise a `tuple` is returned where the first element is the
            pair `(rgb, depth)` and the second element is a list of `bool`s indicating whether the corresponding
            generated image contains "not-safe-for-work" (nsfw) content (`None` when the safety checker did not
            run).
    """
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # IP-Adapter conditioning is active when either raw images or precomputed embeddings are supplied.
    use_ip_adapter = ip_adapter_image is not None or ip_adapter_image_embeds is not None
    if use_ip_adapter:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            self.do_classifier_free_guidance,
        )

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        self.do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clip_skip=clip_skip,
    )

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps, sigmas
    )

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 6.1 Add image embeds for IP-Adapter.
    # Must mirror the condition under which `image_embeds` was computed above; checking only
    # `ip_adapter_image` would silently drop precomputed `ip_adapter_image_embeds`.
    added_cond_kwargs = {"image_embeds": image_embeds} if use_ip_adapter else None

    # 6.2 Optionally get Guidance Scale Embedding
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_cond,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # the callback may hand back replacement tensors for the next step
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    rgb, depth = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return ((rgb, depth), has_nsfw_concept)

    return LDM3DPipelineOutput(rgb=rgb, depth=depth, nsfw_content_detected=has_nsfw_concept)
get_objects_from_module, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py ================================================ # Copyright 2024 MultiDiffusion Authors and The HuggingFace Team. All rights reserved." # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import copy import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDIMScheduler from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler >>> model_ckpt = "stabilityai/stable-diffusion-2-base" >>> scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler") >>> pipe = StableDiffusionPanoramaPipeline.from_pretrained( ... model_ckpt, scheduler=scheduler, torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> prompt = "a photo of the dolomites" >>> image = pipe(prompt).images[0] ``` """ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` via its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or the scheduler's
    default spacing for `num_inference_steps`. Extra `kwargs` are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. Only used when
            neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            The device passed on to `scheduler.set_timesteps`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. Mutually exclusive
            with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. Mutually exclusive
            with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's `timesteps` attribute after configuration, and the number of
        inference steps (the length of the schedule when a custom schedule was supplied, otherwise
        `num_inference_steps` unchanged).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default spacing: nothing custom was requested, so the step count is returned as given.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # A custom schedule was requested; make sure this scheduler can accept it.
    accepted_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: DDIMScheduler,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    image_encoder: Optional[CLIPVisionModelWithProjection] = None,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and derive the VAE scale factor and image processor.

    A warning is logged when `safety_checker` is `None` while `requires_safety_checker` is `True`.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message must be an f-string, otherwise `{self.__class__}` is shown literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    # One factor of 2 per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        A tuple `(prompt_embeds, negative_prompt_embeds)`. Both are repeated `num_images_per_prompt` times along
        the batch dimension. When `do_classifier_free_guidance` is false, `negative_prompt_embeds` is returned
        exactly as it was passed in (possibly `None`).

    Raises:
        TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # Batch size comes from the prompt when given, otherwise from the precomputed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation so that dropped text can be detected and reported.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # Only pass an attention mask when the text encoder config asks for one.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Pick the dtype to cast to: text encoder first, then UNet, else keep the embeddings' own dtype.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            # No negative prompt: use one empty string per batch element.
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        # NOTE(review): `clip_skip` is not applied on this negative-prompt path; confirm that is intended.
        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) if output_hidden_states: image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True ).hidden_states[-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) return image_enc_hidden_states, uncond_image_enc_hidden_states else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): image_embeds = [] if do_classifier_free_guidance: negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): raise ValueError( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """
    Run `self.safety_checker` on `image`, if a checker is configured.

    Args:
        image: Generated images, either a tensor or a non-tensor input accepted by
            `self.image_processor.numpy_to_pil`.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the checker's `clip_input` pixel values are cast to.

    Returns:
        `tuple`: `(image, has_nsfw_concept)`. When no checker is configured, `image` is returned untouched
        and `has_nsfw_concept` is `None`; otherwise both values come from the checker.
    """
    checker = self.safety_checker
    if checker is None:
        return image, None

    # The feature extractor is fed PIL images, so convert from whichever form we were given.
    processor = self.image_processor
    if torch.is_tensor(image):
        pil_images = processor.postprocess(image, output_type="pil")
    else:
        pil_images = processor.numpy_to_pil(image)

    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    checked_image, nsfw_flags = checker(images=image, clip_input=clip_features.pixel_values.to(dtype))
    return checked_image, nsfw_flags
def decode_latents_with_padding(self, latents: torch.Tensor, padding: int = 8) -> torch.Tensor:
    """
    Decode the given latents with horizontal wrap-around padding for circular inference.

    The `padding` rightmost latent columns are prepended and the `padding` leftmost columns are appended
    before decoding, and the corresponding pixels are cropped from the decoded image afterwards.

    Args:
        latents (torch.Tensor): The input latents to decode. They are divided by
            `self.vae.config.scaling_factor` before decoding.
        padding (int, optional): The number of latent columns to add on each side. Defaults to 8.
            `0` disables padding and decodes the latents directly.

    Returns:
        torch.Tensor: The decoded image with the padding pixels removed.

    Raises:
        ValueError: If `padding` is negative.

    Notes:
        - The padding is added to remove boundary artifacts and improve the output quality.
        - This would slightly increase the memory usage.
    """
    if padding < 0:
        raise ValueError(f"`padding` has to be non-negative but is {padding}.")

    latents = 1 / self.vae.config.scaling_factor * latents

    if padding == 0:
        # `latents[..., -0:]` selects every column and `image[..., 0:-0]` selects none, so the generic
        # path below would duplicate the latents and return an empty image.
        return self.vae.decode(latents, return_dict=False)[0]

    latents_left = latents[..., :padding]
    latents_right = latents[..., -padding:]
    # Wrap around: right edge goes in front, left edge goes behind.
    latents = torch.cat((latents_right, latents, latents_left), dim=-1)

    image = self.vae.decode(latents, return_dict=False)[0]

    # One latent column corresponds to `vae_scale_factor` pixel columns.
    padding_pix = self.vae_scale_factor * padding
    return image[..., padding_pix:-padding_pix]
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the arguments of a pipeline call and raise on the first inconsistency found.

    Args:
        prompt: `str`, `list` or `None`; mutually exclusive with `prompt_embeds`.
        height, width (`int`): Requested image size in pixels; both have to be divisible by 8.
        callback_steps: `None` or a positive `int`.
        negative_prompt: Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds, negative_prompt_embeds: When both are given their shapes have to match.
        ip_adapter_image: Mutually exclusive with `ip_adapter_image_embeds`.
        ip_adapter_image_embeds: `None` or a non-empty `list` whose elements all have 3 or 4 dimensions.
        callback_on_step_end_tensor_inputs: `None` or names that are all listed in
            `self._callback_tensor_inputs`.

    Raises:
        ValueError: If any of the rules above is violated.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None:
        unknown_inputs = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown_inputs:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown_inputs}"
            )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        if not ip_adapter_image_embeds:
            # Indexing element 0 of an empty list would surface as an unrelated IndexError.
            raise ValueError("`ip_adapter_image_embeds` has to be a non-empty list of 3D or 4D tensors but is empty")
        # Validate every entry, not just the first one: each IP-Adapter gets its own tensor.
        for single_image_embeds in ip_adapter_image_embeds:
            if single_image_embeds.ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {single_image_embeds.ndim}D"
                )
# Adapted from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Build sinusoidal embeddings of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance-scale values to embed; the embeddings are subsequently used to enrich
            timestep embeddings.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate. Has to be at least 4.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, located on `w`'s device.

    Raises:
        ValueError: If `w` is not 1-D or `embedding_dim` is smaller than 4.
    """
    if w.ndim != 1:
        raise ValueError(f"`w` has to be a 1-D tensor but has shape {tuple(w.shape)}.")

    half_dim = embedding_dim // 2
    if half_dim < 2:
        # With a single frequency the spacing below divides by `half_dim - 1 == 0` and yields NaNs.
        raise ValueError(f"`embedding_dim` has to be at least 4 but is {embedding_dim}.")

    w = w * 1000.0

    # Geometric frequency ladder from 1 down to 1/10000, built on the same device as `w`.
    log_spacing = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -log_spacing)

    emb = w.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd target width: append one zero column.
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = 512,
    width: Optional[int] = 2048,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    guidance_scale: float = 7.5,
    view_batch_size: int = 1,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    circular_padding: bool = False,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs: Any,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to 512):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to 2048):
            The width in pixels of the generated image. The width is kept high because the pipeline is supposed
            to generate panorama-like images.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            The timesteps at which to generate the images. If not specified, then the default timestep spacing
            strategy of the scheduler is used.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        view_batch_size (`int`, *optional*, defaults to 1):
            The batch size to denoise split views. For some GPUs with high performance, higher view batch size
            can speedup the generation and increase the VRAM usage.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
            Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
            IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
            contain the negative image embedding when classifier-free guidance is active
            (`guidance_scale > 1`). If not provided, embeddings are computed from the `ip_adapter_image` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescaling factor for the guidance embeddings. A value of 0.0 means no rescaling is applied.
        circular_padding (`bool`, *optional*, defaults to `False`):
            If set to `True`, circular padding is applied to ensure there are no stitching artifacts. Circular
            padding allows the model to seamlessly generate a transition from the rightmost part of the image to
            the leftmost part, maintaining consistency in a 360-degree sense.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference. The function is called
            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
        )
    if callback is not None and callback_steps is None:
        # The legacy callback is gated by `i % callback_steps`; without a step count that raises a TypeError.
        callback_steps = 1

    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        ip_adapter_image,
        ip_adapter_image_embeds,
        callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    # NOTE: the `do_classifier_free_guidance` property of this pipeline always returns `False`, so this
    # local flag is the only reliable source of truth inside `__call__`.
    do_classifier_free_guidance = guidance_scale > 1.0

    uses_ip_adapter = ip_adapter_image is not None or ip_adapter_image_embeds is not None
    if uses_ip_adapter:
        image_embeds = self.prepare_ip_adapter_image_embeds(
            ip_adapter_image,
            ip_adapter_image_embeds,
            device,
            batch_size * num_images_per_prompt,
            do_classifier_free_guidance,
        )

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=clip_skip,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Define panorama grid and initialize views for synthesis.
    # prepare batch grid
    views = self.get_views(height, width, circular_padding=circular_padding)
    views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
    # One independent snapshot per view batch. `[snapshot] * n` would alias a single dict, so state that a
    # scheduler mutates in place during the first step would leak from one view batch into the next.
    views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__) for _ in views_batch]
    count = torch.zeros_like(latents)
    value = torch.zeros_like(latents)

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7.1 Add image embeds for IP-Adapter
    added_cond_kwargs = {"image_embeds": image_embeds} if uses_ip_adapter else None

    # 7.2 Optionally get Guidance Scale Embedding
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    # 8. Denoising loop
    # Each denoising step also includes refinement of the latents with respect to the
    # views.
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            count.zero_()
            value.zero_()
            latent_width = latents.shape[3]

            # generate views
            # Here, we iterate through different spatial crops of the latents and denoise them. These
            # denoised (latent) crops are then averaged to produce the final latent
            # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
            # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
            # Batch views denoise
            for j, batch_view in enumerate(views_batch):
                vb_size = len(batch_view)
                # get the latents corresponding to the current view coordinates
                if circular_padding:
                    latents_for_view = []
                    for h_start, h_end, w_start, w_end in batch_view:
                        if w_end > latent_width:
                            # Add circular horizontal padding
                            latent_view = torch.cat(
                                (
                                    latents[:, :, h_start:h_end, w_start:],
                                    latents[:, :, h_start:h_end, : w_end - latent_width],
                                ),
                                dim=-1,
                            )
                        else:
                            latent_view = latents[:, :, h_start:h_end, w_start:w_end]
                        latents_for_view.append(latent_view)
                    latents_for_view = torch.cat(latents_for_view)
                else:
                    latents_for_view = torch.cat(
                        [
                            latents[:, :, h_start:h_end, w_start:w_end]
                            for h_start, h_end, w_start, w_end in batch_view
                        ]
                    )

                # rematch block's scheduler status
                self.scheduler.__dict__.update(views_scheduler_status[j])

                # expand the latents if we are doing classifier free guidance
                # NOTE(review): the interleaved (uncond, text) layout used here lines up with the
                # concatenated `[negative, positive]` prompt embeddings only for a single prompt/image;
                # confirm before relying on batch sizes larger than 1 with guidance enabled.
                latent_model_input = (
                    latents_for_view.repeat_interleave(2, dim=0)
                    if do_classifier_free_guidance
                    else latents_for_view
                )
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # repeat prompt_embeds for batch
                prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds_input,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # Gate on the local flag: the property of the same name is hard-wired to `False`, which
                # silently disabled `guidance_rescale`.
                if do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(
                        noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents_denoised_batch = self.scheduler.step(
                    noise_pred, t, latents_for_view, **extra_step_kwargs
                ).prev_sample

                # save views scheduler status after sample
                views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)

                # extract value from batch
                for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
                    latents_denoised_batch.chunk(vb_size), batch_view
                ):
                    if circular_padding and w_end > latent_width:
                        # Case for circular padding: the window wraps around the right edge, so its left
                        # part lands at the end of the panorama and its right part at the beginning.
                        # The window tensor is indexed in its own (local) coordinates, i.e. all rows;
                        # slicing it with the panorama rows `h_start:h_end` is only correct for the top
                        # row of views.
                        wrap_at = latent_width - w_start
                        value[:, :, h_start:h_end, w_start:] += latents_view_denoised[:, :, :, :wrap_at]
                        value[:, :, h_start:h_end, : w_end - latent_width] += latents_view_denoised[
                            :, :, :, wrap_at:
                        ]
                        count[:, :, h_start:h_end, w_start:] += 1
                        count[:, :, h_start:h_end, : w_end - latent_width] += 1
                    else:
                        value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                        count[:, :, h_start:h_end, w_start:w_end] += 1

            # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
            latents = torch.where(count > 0, value / count, value)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    if output_type != "latent":
        if circular_padding:
            image = self.decode_latents_with_padding(latents)
        else:
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@dataclass
class SafetyConfig(object):
    # Preset keyword-argument bundles for Safe Latent Diffusion, ordered from the weakest to the strongest
    # intervention. They are plain class attributes (no annotations), so the dataclass declares no fields.
    WEAK = {
        "sld_warmup_steps": 15,
        "sld_guidance_scale": 20,
        "sld_threshold": 0.0,
        "sld_momentum_scale": 0.0,
        "sld_mom_beta": 0.0,
    }
    MEDIUM = {
        "sld_warmup_steps": 10,
        "sld_guidance_scale": 1000,
        "sld_threshold": 0.01,
        "sld_momentum_scale": 0.3,
        "sld_mom_beta": 0.4,
    }
    STRONG = {
        "sld_warmup_steps": 7,
        "sld_guidance_scale": 2000,
        "sld_threshold": 0.025,
        "sld_momentum_scale": 0.5,
        "sld_mom_beta": 0.7,
    }
    MAX = {
        "sld_warmup_steps": 0,
        "sld_guidance_scale": 5000,
        "sld_threshold": 1.0,
        "sld_momentum_scale": 0.5,
        "sld_mom_beta": 0.7,
    }


# Names that are attached to the lazy module object after it has been installed in `sys.modules`.
_dummy_objects = {}
_additional_imports = {}
# Maps submodule name -> public names that the lazy module resolves on first attribute access.
_import_structure = {}

_additional_imports.update({"SafetyConfig": SafetyConfig})

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure.update(
        {
            "pipeline_output": ["StableDiffusionSafePipelineOutput"],
            "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
            # Has to match the name imported from `.safety_checker` in the eager branch below and in
            # `pipeline_stable_diffusion_safe`; the lazy table previously advertised
            # `StableDiffusionSafetyChecker`, which neither of those imports uses.
            "safety_checker": ["SafeStableDiffusionSafetyChecker"],
        }
    )


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_output import StableDiffusionSafePipelineOutput
        from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe
        from .safety_checker import SafeStableDiffusionSafetyChecker

else:
    import sys

    # Replace this module with a lazy proxy so that the heavy submodules are only imported on demand.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
applied_safety_concept (`str`) The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled """ images: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: Optional[List[bool]] unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]] applied_safety_concept: Optional[str] ================================================ FILE: src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py ================================================ import inspect import warnings from typing import Callable, List, Optional, Union import numpy as np import torch from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput from ...loaders import IPAdapterMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import StableDiffusionSafePipelineOutput from .safety_checker import SafeStableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAdapterMixin): r""" Pipeline based on the [`StableDiffusionPipeline`] for text-to-image generation using Safe Latent Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
The pipeline also inherits the following loading methods: - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: SafeStableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    image_encoder: Optional[CLIPVisionModelWithProjection] = None,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and patch outdated scheduler / unet configurations.

    The scheduler config is rewritten in place when `steps_offset != 1` or `clip_sample is True`, and the unet
    config is rewritten when it comes from a pre-0.9.0 checkpoint with `sample_size < 64`; each rewrite emits a
    deprecation notice.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()
    safety_concept: Optional[str] = (
        "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity,"
        " bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child"
        " abuse, brutality, cruelty"
    )

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The `f` prefix is required so that `{self.__class__}` is interpolated instead of printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Compare on the base version so that dev / rc suffixes of the stored version are ignored.
    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    self._safety_text_concept = safety_concept
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
Setter method for the safety concept used with SLD Args: concept (`str`): The text of the new safety concept """ self._safety_text_concept = concept def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). """ batch_size = len(prompt) if isinstance(prompt, list) else 1 text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids if not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = 
prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # Encode the safety concept text if enable_safety_guidance: safety_concept_input = self.tokenizer( [self._safety_text_concept], padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0] # duplicate safety embeddings 
def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
    """
    Run the safety checker on decoded images and black out the ones it flags.

    Args:
        image (`np.ndarray`): Batch of decoded images; flagged entries are overwritten with zeros.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the CLIP pixel values are cast to before they are passed to the safety checker.
        enable_safety_guidance (`bool`): Only selects the wording of the warning that is logged.

    Returns:
        `tuple`: `(image, has_nsfw_concept, flagged_images)`. `has_nsfw_concept` is the per-image flag list
        returned by the safety checker. `flagged_images` has the same shape as `image`; it holds the original
        pixels of every flagged image and zeros everywhere else. The last two entries are `None` when the
        pipeline has no safety checker.
    """
    if self.safety_checker is None:
        return image, None, None

    # Keep an untouched copy: flagged entries of `image` are replaced by black images below.
    images = image.copy()
    safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
    image, has_nsfw_concept = self.safety_checker(
        images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
    )

    # One slot per image in the batch, so any index can be flagged (a fixed size of 2 overflowed on larger
    # batches).
    flagged_images = np.zeros(image.shape)
    if any(has_nsfw_concept):
        logger.warning(
            "Potential NSFW content was detected in one or more images. A black image will be returned"
            " instead. "
            f"{'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}"
        )
        # Use a separate loop variable so the returned `has_nsfw_concept` stays the full per-image list.
        for idx, is_flagged in enumerate(has_nsfw_concept):
            if is_flagged:
                flagged_images[idx] = images[idx]
                image[idx] = np.zeros(image[idx].shape)  # black image

    return image, has_nsfw_concept, flagged_images
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the call arguments and raise `ValueError` on the first violated rule.

    Rules, checked in this order: `height` and `width` are multiples of 8; `callback_steps` is `None` or a
    positive integer; every requested callback tensor name is listed in `self._callback_tensor_inputs`; exactly
    one of `prompt` / `prompt_embeds` is given and `prompt` is a `str` or `list`; `negative_prompt` and
    `negative_prompt_embeds` are not both given; `prompt_embeds` and `negative_prompt_embeds` share one shape.
    """
    if any(side % 8 != 0 for side in (height, width)):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is not None:
        steps_ok = isinstance(callback_steps, int) and callback_steps > 0
        if not steps_ok:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    if callback_on_step_end_tensor_inputs is not None:
        if not all(k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs):
            unknown_names = [
                k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs
            ]
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown_names}"
            )

    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None
    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
def perform_safety_guidance(
    self,
    enable_safety_guidance,
    safety_momentum,
    noise_guidance,
    noise_pred_out,
    i,
    sld_guidance_scale,
    sld_warmup_steps,
    sld_threshold,
    sld_momentum_scale,
    sld_mom_beta,
):
    """
    Apply one step of Safe Latent Diffusion (SLD) guidance to the classifier-free guidance term.

    Args:
        enable_safety_guidance (`bool`): When `False`, the inputs are returned unchanged.
        safety_momentum (`torch.Tensor` or `None`): Running momentum; initialised to zeros when `None`.
        noise_guidance (`torch.Tensor`): Classifier-free guidance term (text prediction minus unconditional).
        noise_pred_out: Noise predictions ordered `(unconditional, text, safety concept)`, i.e. the order in
            which the pipeline concatenates the embeddings and chunks the UNet output.
        i (`int`): Index of the current diffusion step.
        sld_guidance_scale, sld_warmup_steps, sld_threshold, sld_momentum_scale, sld_mom_beta:
            SLD hyper-parameters.

    Returns:
        `tuple`: `(noise_guidance, safety_momentum)` after the update.
    """
    if not enable_safety_guidance:
        return noise_guidance, safety_momentum

    if safety_momentum is None:
        safety_momentum = torch.zeros_like(noise_guidance)

    # Same ordering as the main denoising loop: index 0 is the unconditional prediction, index 1 the text one.
    noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
    noise_pred_safety_concept = noise_pred_out[2]

    # Equation 6
    text_minus_safety = noise_pred_text - noise_pred_safety_concept
    scale = torch.clamp(torch.abs(text_minus_safety) * sld_guidance_scale, max=1.0)

    # Equation 6
    safety_concept_scale = torch.where(text_minus_safety >= sld_threshold, torch.zeros_like(scale), scale)

    # Equation 4
    noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale)

    # Equation 7
    noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum

    # Equation 8
    safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety

    # Momentum keeps building during warmup; the guidance term itself is only corrected afterwards.
    if i >= sld_warmup_steps:
        # Equation 3
        noise_guidance = noise_guidance - noise_guidance_safety

    return noise_guidance, safety_momentum
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    sld_guidance_scale: Optional[float] = 1000,
    sld_warmup_steps: Optional[int] = 10,
    sld_threshold: Optional[float] = 0.01,
    sld_momentum_scale: Optional[float] = 0.3,
    sld_mom_beta: Optional[float] = 0.4,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide image generation.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. Ignored when not using
            guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        ip_adapter_image: (`PipelineImageInput`, *optional*):
            Optional image input to work with IP Adapters.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`StableDiffusionSafePipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that calls every `callback_steps` steps during inference. The function is called with the
            following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. If not specified, the callback is called at
            every step.
        sld_guidance_scale (`float`, *optional*, defaults to 1000):
            Safety guidance is only applied when `sld_guidance_scale > 1` and classifier-free guidance is enabled
            (`guidance_scale > 1`); otherwise it is disabled.
        sld_warmup_steps (`int`, *optional*, defaults to 10):
            Number of warmup steps for safety guidance. SLD is only applied for diffusion steps greater than or
            equal to `sld_warmup_steps`.
        sld_threshold (`float`, *optional*, defaults to 0.01):
            Threshold that separates the hyperplane between appropriate and inappropriate images.
        sld_momentum_scale (`float`, *optional*, defaults to 0.3):
            Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0,
            momentum is disabled. Momentum is built up during warmup for diffusion steps smaller than
            `sld_warmup_steps`.
        sld_mom_beta (`float`, *optional*, defaults to 0.4):
            Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous
            momentum is kept. Momentum is built up during warmup for diffusion steps smaller than
            `sld_warmup_steps`.

    Returns:
        [`StableDiffusionSafePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, a [`StableDiffusionSafePipelineOutput`] is returned, otherwise a `tuple`
            `(images, nsfw_content_detected, applied_safety_concept, unsafe_images)`: the generated images, the
            per-image "not-safe-for-work" (nsfw) flags, the safety concept text (or `None` when safety guidance
            was disabled) and the images flagged by the safety checker.

    Examples:

    ```py
    import torch
    from diffusers import StableDiffusionPipelineSafe
    from diffusers.pipelines.stable_diffusion_safe import SafetyConfig

    pipeline = StableDiffusionPipelineSafe.from_pretrained(
        "AIML-TUDA/stable-diffusion-safe", torch_dtype=torch.float16
    ).to("cuda")
    prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"
    image = pipeline(prompt=prompt, **SafetyConfig.MEDIUM).images[0]
    ```
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(prompt, height, width, callback_steps)

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance
    if not enable_safety_guidance:
        # Only the SLD guidance is off here; the safety checker (if registered) still runs after decoding.
        warnings.warn("Safety guidance disabled!")

    if ip_adapter_image is not None:
        output_hidden_state = not isinstance(self.unet.encoder_hid_proj, ImageProjection)
        image_embeds, negative_image_embeds = self.encode_image(
            ip_adapter_image, device, num_images_per_prompt, output_hidden_state
        )
        if do_classifier_free_guidance:
            if enable_safety_guidance:
                # Three-way batch: unconditional, text and safety-concept branches.
                image_embeds = torch.cat([negative_image_embeds, image_embeds, image_embeds])
            else:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])

    # 3. Encode input prompt
    prompt_embeds = self._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs.
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 6.1 Add image embeds for IP-Adapter
    added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

    safety_momentum = None

    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * (3 if enable_safety_guidance else 2))
                if do_classifier_free_guidance
                else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2))
                noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]

                # default classifier free guidance
                noise_guidance = noise_pred_text - noise_pred_uncond

                # Perform SLD guidance
                if enable_safety_guidance:
                    if safety_momentum is None:
                        safety_momentum = torch.zeros_like(noise_guidance)
                    noise_pred_safety_concept = noise_pred_out[2]

                    # Equation 6
                    scale = torch.clamp(
                        torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0
                    )

                    # Equation 6
                    safety_concept_scale = torch.where(
                        (noise_pred_text - noise_pred_safety_concept) >= sld_threshold,
                        torch.zeros_like(scale),
                        scale,
                    )

                    # Equation 4
                    noise_guidance_safety = torch.mul(
                        (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale
                    )

                    # Equation 7
                    noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum

                    # Equation 8
                    safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety

                    if i >= sld_warmup_steps:  # Warmup
                        # Equation 3
                        noise_guidance = noise_guidance - noise_guidance_safety

                noise_pred = noise_pred_uncond + guidance_scale * noise_guidance

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    # 8. Post-processing
    image = self.decode_latents(latents)

    # 9. Run safety checker
    image, has_nsfw_concept, flagged_images = self.run_safety_checker(
        image, device, prompt_embeds.dtype, enable_safety_guidance
    )

    # 10. Convert to PIL
    if output_type == "pil":
        image = self.numpy_to_pil(image)
        if flagged_images is not None:
            flagged_images = self.numpy_to_pil(flagged_images)

    applied_safety_concept = self._safety_text_concept if enable_safety_guidance else None

    if not return_dict:
        # NOTE: the tuple order differs from the field order of the output dataclass; kept for compatibility.
        return (
            image,
            has_nsfw_concept,
            applied_safety_concept,
            flagged_images,
        )

    return StableDiffusionSafePipelineOutput(
        images=image,
        nsfw_content_detected=has_nsfw_concept,
        applied_safety_concept=applied_safety_concept,
        unsafe_images=flagged_images,
    )
def cosine_distance(image_embeds, text_embeds):
    """Return the pairwise cosine similarity between two batches of embeddings.

    Both inputs are normalised with `nn.functional.normalize` and then matrix-multiplied, so entry `[i, j]` of
    the result compares `image_embeds[i]` with `text_embeds[j]`. Despite the function name, larger values mean
    the two vectors are *more* alike.
    """
    normalized_image_embeds = nn.functional.normalize(image_embeds)
    normalized_text_embeds = nn.functional.normalize(text_embeds)
    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())


class SafeStableDiffusionSafetyChecker(PreTrainedModel):
    """CLIP-based checker that flags images whose embedding lies too close to stored concept embeddings.

    An image is flagged when, for at least one concept, its similarity to the concept embedding exceeds the
    threshold stored for that concept (plus a small adjustment, see `forward`).
    """

    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        """Build the vision tower, the projection layer and the (frozen) concept tables.

        The concept embeddings and thresholds are created as all-ones, non-trainable parameters;
        NOTE(review): presumably they are overwritten when pretrained weights are loaded -- confirm.
        """
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)

        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)

        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)

    @torch.no_grad()
    def forward(self, clip_input, images):
        """Score every image against the stored concepts.

        Args:
            clip_input: Input handed to `self.vision_model`.
            images: The images being checked. They are returned untouched.

        Returns:
            `(images, has_nsfw_concepts)` where `has_nsfw_concepts` is a list with one `bool` per batch entry,
            `True` when at least one regular concept scored above its threshold.
        """
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()

        result = []
        batch_size = image_embeds.shape[0]
        for i in range(batch_size):
            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concept_idx in range(len(special_cos_dist[0])):
                concept_cos = special_cos_dist[i][concept_idx]
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                score = round(concept_cos - concept_threshold + adjustment, 3)
                result_img["special_scores"][concept_idx] = score
                if score > 0:
                    # Record an ordered (index, score) pair. A set literal here would lose the pairing and
                    # could even collapse to a single element when both values compare equal.
                    result_img["special_care"].append((concept_idx, score))
                    # From now on this image is judged more strictly: the bump applies to the remaining
                    # special-care concepts and to every regular concept below.
                    adjustment = 0.01

            for concept_idx in range(len(cos_dist[0])):
                concept_cos = cos_dist[i][concept_idx]
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                score = round(concept_cos - concept_threshold + adjustment, 3)
                result_img["concept_scores"][concept_idx] = score
                if score > 0:
                    result_img["bad_concepts"].append(concept_idx)

            result.append(result_img)

        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]

        return images, has_nsfw_concepts

    @torch.no_grad()
    def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor):
        """Tensor-only variant of `forward` (no Python loops, no NumPy round trip).

        Differences to `forward` that are visible in the code: scores are not rounded, and the 0.01 bump
        triggered by a special-care hit is added to the regular concept scores only.

        Returns:
            `(images, has_nsfw_concepts)` where `has_nsfw_concepts` is a boolean tensor with one entry per
            batch element; `images` is returned untouched.
        """
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
        cos_dist = cosine_distance(image_embeds, self.concept_embeds)

        # increase this value to create a stronger `nsfw` filter
        # at the cost of increasing the possibility of filtering benign images
        adjustment = 0.0

        special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment

        # One flag per image: did any special-care concept fire?
        special_care = torch.any(special_scores > 0, dim=1)
        special_adjustment = special_care * 0.01
        # Broadcast the per-image bump across all regular concepts.
        special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])

        concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
        has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)

        return images, has_nsfw_concepts
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionSAGPipeline >>> pipe = StableDiffusionSAGPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 ... 
class CrossAttnStoreProcessor:
    """Attention processor that remembers the attention probabilities of its most recent call.

    After every `__call__` the probabilities returned by `attn.get_attention_scores` are available as
    `self.attention_probs`; before the first call the attribute is `None`.
    """

    def __init__(self):
        # Overwritten on every invocation of `__call__`.
        self.attention_probs = None

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        """Run attention through the layers of `attn` and return the projected hidden states."""
        n_batch, n_tokens, _channels = hidden_states.shape
        mask = attn.prepare_attention_mask(attention_mask, n_tokens, n_batch)

        q = attn.to_q(hidden_states)

        # Without an explicit context the layer attends to its own input; a supplied context is
        # normalised first when the attention module asks for it.
        context = encoder_hidden_states
        if context is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(context)

        k = attn.to_k(context)
        v = attn.to_v(context)

        q, k, v = (attn.head_to_batch_dim(tensor) for tensor in (q, k, v))

        probs = attn.get_attention_scores(q, k, mask)
        self.attention_probs = probs

        attended = attn.batch_to_head_dim(torch.bmm(probs, v))

        projected = attn.to_out[0](attended)  # linear proj
        return attn.to_out[1](projected)  # dropout
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
""" model_cpu_offload_seq = "text_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, image_encoder: Optional[CLIPVisionModelWithProjection] = None, requires_safety_checker: bool = True, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) prompt_embeds_tuple = self.encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=lora_scale, **kwargs, ) # concatenate for backwards comp prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt def encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 
attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into # the tuple to access the hidden states from the desired layer. prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode an image with `self.image_encoder` and build a matching unconditional counterpart.

    Args:
        image: A `torch.Tensor`, or anything `self.feature_extractor` can turn into `pixel_values`.
        device: Device the pixel values are moved to before encoding.
        num_images_per_prompt: Repeat count applied along dim 0 of every returned tensor.
        output_hidden_states: When truthy, return the second-to-last hidden states of the encoder instead
            of its `image_embeds`.

    Returns:
        A `(conditional, unconditional)` tuple. With hidden states, the unconditional part is the encoder
        output for an all-zero image; otherwise it is an all-zero tensor shaped like the embeddings.
    """
    encoder = self.image_encoder
    # The encoder's parameter dtype decides the dtype of the pixel values.
    encoder_dtype = next(encoder.parameters()).dtype

    pixel_values = image
    if not isinstance(pixel_values, torch.Tensor):
        pixel_values = self.feature_extractor(pixel_values, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = encoder(pixel_values).image_embeds
        cond_embeds = cond_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    cond_states = encoder(pixel_values, output_hidden_states=True).hidden_states[-2]
    cond_states = cond_states.repeat_interleave(num_images_per_prompt, dim=0)

    blank_pixels = torch.zeros_like(pixel_values)
    uncond_states = encoder(blank_pixels, output_hidden_states=True).hidden_states[-2]
    uncond_states = uncond_states.repeat_interleave(num_images_per_prompt, dim=0)
    return cond_states, uncond_states
def run_safety_checker(self, image, device, dtype):
    """Pass `image` through the safety checker, if the pipeline has one.

    Args:
        image: Generated images, either a tensor or an array accepted by `image_processor.numpy_to_pil`.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the `pixel_values` are cast to before they reach the checker.

    Returns:
        `(image, has_nsfw_concept)`. Without a safety checker the image is returned as given and the
        second element is `None`; otherwise both values come from the checker.
    """
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images, so convert according to the input kind.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    checked_image, nsfw_flags = self.safety_checker(
        images=image, clip_input=clip_features.pixel_values.to(dtype)
    )
    return checked_image, nsfw_flags
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that `self.scheduler.step` is able to accept.

    Not all schedulers share the same `step` signature, so each optional argument is forwarded only when
    the signature declares a parameter of that name.

    Args:
        generator: Value forwarded as `generator` when the scheduler accepts it.
        eta: Value forwarded as `eta` when the scheduler accepts it.

    Returns:
        A dict containing a subset of the keys `"eta"` and `"generator"`.
    """
    # Inspect the signature once; the parameter mapping supports `in` directly, so neither a second
    # inspection nor a throw-away set of its keys is needed.
    step_parameters = inspect.signature(self.scheduler.step).parameters

    extra_step_kwargs = {}

    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]
    if "eta" in step_parameters:
        extra_step_kwargs["eta"] = eta

    # check if the scheduler accepts generator
    if "generator" in step_parameters:
        extra_step_kwargs["generator"] = generator

    return extra_step_kwargs
) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." 
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """Return the starting latents for the denoising loop, scaled by the scheduler's initial noise sigma.

    Args:
        batch_size: Effective batch size; a list of generators must have exactly this length.
        num_channels_latents: Channel count used for freshly sampled latents.
        height, width: Image size in pixels; each is integer-divided by `self.vae_scale_factor`.
        dtype, device: Used when sampling new latents; `device` is also where supplied latents are moved.
        generator: A generator or list of generators handed to `randn_tensor`.
        latents: Optional pre-made latents. When given they are only moved to `device`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(width) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        start = latents.to(device)
    else:
        start = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return start * self.scheduler.init_noise_sigma
Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. sag_scale (`float`, *optional*, defaults to 0.75): Chosen between [0, 1.0] for better quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 # and `sag_scale` is` `s` of equation (16) # of the self-attention guidance paper: https://arxiv.org/pdf/2210.00939.pdf # `sag_scale = 0` means no self-attention guidance do_self_attention_guidance = sag_scale > 0.0 if ip_adapter_image is not None or ip_adapter_image_embeds is not None: ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, do_classifier_free_guidance, ) if do_classifier_free_guidance: image_embeds = [] negative_image_embeds = [] for tmp_image_embeds in ip_adapter_image_embeds: single_negative_image_embeds, single_image_embeds = tmp_image_embeds.chunk(2) image_embeds.append(single_image_embeds) negative_image_embeds.append(single_negative_image_embeds) else: image_embeds = ip_adapter_image_embeds # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clip_skip=clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps if timesteps.dtype not in [torch.int16, torch.int32, torch.int64]: raise ValueError( f"{self.__class__.__name__} does not support using a scheduler of type {self.scheduler.__class__.__name__}. Please make sure to use one of 'DDIMScheduler, PNDMScheduler, DDPMScheduler, DEISMultistepScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler'." ) # 5. 
Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.1 Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None ) if do_classifier_free_guidance: added_uncond_kwargs = ( {"image_embeds": negative_image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None ) # 7. Denoising loop original_attn_proc = self.unet.attn_processors store_processor = CrossAttnStoreProcessor() self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order map_size = None def get_map_size(module, input, output): nonlocal map_size map_size = output[0].shape[-2:] with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size): with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # perform self-attention guidance with the stored 
self-attention map if do_self_attention_guidance: # classifier-free guidance produces two chunks of attention map # and we only use unconditional one according to equation (25) # in https://arxiv.org/pdf/2210.00939.pdf if do_classifier_free_guidance: # DDIM-like prediction of x0 pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) # get the stored attention maps uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) # self-attention-based degrading of latents degraded_latents = self.sag_masking( pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t) ) uncond_emb, _ = prompt_embeds.chunk(2) # forward and give guidance degraded_pred = self.unet( degraded_latents, t, encoder_hidden_states=uncond_emb, added_cond_kwargs=added_uncond_kwargs, ).sample noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) else: # DDIM-like prediction of x0 pred_x0 = self.pred_x0(latents, noise_pred, t) # get the stored attention maps cond_attn = store_processor.attention_probs # self-attention-based degrading of latents degraded_latents = self.sag_masking( pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t) ) # forward and give guidance degraded_pred = self.unet( degraded_latents, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs, ).sample noise_pred += sag_scale * (noise_pred - degraded_pred) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, 
prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) self.maybe_free_model_hooks() # make sure to set the original attention processors back self.unet.set_attn_processor(original_attn_proc) if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) def sag_masking(self, original_latents, attn_map, map_size, t, eps): # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf bh, hw1, hw2 = attn_map.shape b, latent_channel, latent_h, latent_w = original_latents.shape h = self.unet.config.attention_head_dim if isinstance(h, list): h = h[-1] # Produce attention mask attn_map = attn_map.reshape(b, h, hw1, hw2) attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 attn_mask = ( attn_mask.reshape(b, map_size[0], map_size[1]) .unsqueeze(1) .repeat(1, latent_channel, 1, 1) .type(attn_map.dtype) ) attn_mask = F.interpolate(attn_mask, (latent_h, latent_w)) # Blur according to the self-attention mask degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) # Noise it again to match the noise level degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t[None]) return degraded_latents # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step # Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.) 
def pred_x0(self, sample, model_output, timestep): alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device) beta_prod_t = 1 - alpha_prod_t if self.scheduler.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) elif self.scheduler.config.prediction_type == "sample": pred_original_sample = model_output elif self.scheduler.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output # predict V model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," " or `v_prediction`" ) return pred_original_sample def pred_epsilon(self, sample, model_output, timestep): alpha_prod_t = self.scheduler.alphas_cumprod[timestep] beta_prod_t = 1 - alpha_prod_t if self.scheduler.config.prediction_type == "epsilon": pred_eps = model_output elif self.scheduler.config.prediction_type == "sample": pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5) elif self.scheduler.config.prediction_type == "v_prediction": pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output else: raise ValueError( f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," " or `v_prediction`" ) return pred_eps # Gaussian blur def gaussian_blur_2d(img, kernel_size, sigma): ksize_half = (kernel_size - 1) * 0.5 x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) pdf = torch.exp(-0.5 * (x / sigma).pow(2)) x_kernel = pdf / pdf.sum() x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] img = 
F.pad(img, padding, mode="reflect") img = F.conv2d(img, kernel2d, groups=img.shape[-3]) return img ================================================ FILE: src/diffusers/pipelines/stable_diffusion_xl/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_flax_available, is_torch_available, is_transformers_available, ) _dummy_objects = {} _additional_imports = {} _import_structure = {"pipeline_output": ["StableDiffusionXLPipelineOutput"]} if is_transformers_available() and is_flax_available(): _import_structure["pipeline_output"].extend(["FlaxStableDiffusionXLPipelineOutput"]) try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_stable_diffusion_xl"] = ["StableDiffusionXLPipeline"] _import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"] _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"] _import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"] if is_transformers_available() and is_flax_available(): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState _additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState}) _import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from 
.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline try: if not (is_transformers_available() and is_flax_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_flax_objects import * else: from .pipeline_flax_stable_diffusion_xl import ( FlaxStableDiffusionXLPipeline, ) from .pipeline_output import FlaxStableDiffusionXLPipelineOutput else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) for name, value in _additional_imports.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from functools import partial
from typing import Dict, List, Optional, Union

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformers import CLIPTokenizer, FlaxCLIPTextModel

from diffusers.utils import logging

from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...schedulers import (
    FlaxDDIMScheduler,
    FlaxDPMSolverMultistepScheduler,
    FlaxLMSDiscreteScheduler,
    FlaxPNDMScheduler,
)
from ..pipeline_flax_utils import FlaxDiffusionPipeline
from .pipeline_output import FlaxStableDiffusionXLPipelineOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False


class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
    """
    Flax/JAX text-to-image pipeline for Stable Diffusion XL.

    The pipeline registers two CLIP tokenizer / text-encoder pairs, a VAE, a UNet and a scheduler. Model weights
    are supplied to every call through the `params` dictionary, which is read under the keys `"text_encoder"`,
    `"text_encoder_2"`, `"unet"`, `"vae"` and `"scheduler"`.
    """

    def __init__(
        self,
        text_encoder: FlaxCLIPTextModel,
        text_encoder_2: FlaxCLIPTextModel,
        vae: FlaxAutoencoderKL,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        dtype: jnp.dtype = jnp.float32,
    ):
        """Register the sub-models and derive the VAE spatial scale factor from its block count."""
        super().__init__()
        self.dtype = dtype

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_inputs(self, prompt: Union[str, List[str]]):
        """
        Tokenize `prompt` with both tokenizers and stack the resulting ids.

        Args:
            prompt (`str` or `List[str]`): The prompt or prompts to tokenize.

        Returns:
            The input ids of both tokenizers stacked along axis 1, i.e. indexable as `ids[:, 0, :]` (first
            tokenizer) and `ids[:, 1, :]` (second tokenizer).

        Raises:
            ValueError: If `prompt` is neither a `str` nor a `list`.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # Assume we have the two encoders
        inputs = []
        for tokenizer in [self.tokenizer, self.tokenizer_2]:
            # Both tokenizers pad/truncate to the first tokenizer's max length so the ids can be stacked.
            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="np",
            )
            inputs.append(text_inputs.input_ids)
        inputs = jnp.stack(inputs, axis=1)
        return inputs

    def __call__(
        self,
        prompt_ids: jax.Array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.Array,
        num_inference_steps: int = 50,
        guidance_scale: Union[float, jax.Array] = 7.5,
        height: Optional[int] = None,
        width: Optional[int] = None,
        latents: Optional[jax.Array] = None,
        neg_prompt_ids: Optional[jax.Array] = None,
        return_dict: bool = True,
        output_type: Optional[str] = None,
        jit: bool = False,
    ):
        """
        Generate images (or latents) for already tokenized prompts.

        Args:
            prompt_ids: Token ids as produced by `prepare_inputs`. With `jit=True` the leading axis is the axis
                that `jax.pmap` maps over.
            params: Model parameters / scheduler state, keyed by component name.
            prng_seed: PRNG key used to sample the initial latents when `latents` is not given.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance weight. A Python scalar is replicated along the leading axis
                of `prompt_ids` when `jit=True`; an array is passed through unchanged.
            height: Output height in pixels; defaults to `unet.config.sample_size * vae_scale_factor`.
            width: Output width in pixels; defaults to `unet.config.sample_size * vae_scale_factor`.
            latents: Optional pre-sampled initial latents.
            neg_prompt_ids: Optional token ids of the negative prompt.
            return_dict: Whether to wrap the result in `FlaxStableDiffusionXLPipelineOutput`.
            output_type: Pass `"latent"` to return the final latents instead of decoded images.
            jit: Whether to run the pmapped `_p_generate` instead of calling `_generate` directly.

        Returns:
            `FlaxStableDiffusionXLPipelineOutput` if `return_dict` is true, otherwise a 1-tuple `(images,)`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # `_p_generate` maps `guidance_scale` over its first axis, so any Python scalar (int as well as float)
        # has to become an array with one entry per mapped slice.
        if isinstance(guidance_scale, (int, float)) and jit:
            # Convert to a tensor so each device gets a copy.
            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
            guidance_scale = guidance_scale[:, None]

        return_latents = output_type == "latent"

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
                return_latents,
            )
        else:
            images = self._generate(
                prompt_ids,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
                return_latents,
            )

        if not return_dict:
            return (images,)

        return FlaxStableDiffusionXLPipelineOutput(images=images)

    def get_embeddings(self, prompt_ids: jax.Array, params):
        """
        Run both text encoders on `prompt_ids`.

        Args:
            prompt_ids: Token ids indexed as `[:, 0, :]` for the first encoder and `[:, 1, :]` for the second.
            params: Parameter dictionary; `"text_encoder"` and `"text_encoder_2"` are read.

        Returns:
            A tuple `(prompt_embeds, text_embeds)`: the second-to-last hidden states of both encoders
            concatenated along the last axis, and the `text_embeds` output of the second encoder.
        """
        # We assume we have the two encoders

        # bs, encoder_input, seq_length
        te_1_inputs = prompt_ids[:, 0, :]
        te_2_inputs = prompt_ids[:, 1, :]

        prompt_embeds = self.text_encoder(te_1_inputs, params=params["text_encoder"], output_hidden_states=True)
        prompt_embeds = prompt_embeds["hidden_states"][-2]
        prompt_embeds_2_out = self.text_encoder_2(
            te_2_inputs, params=params["text_encoder_2"], output_hidden_states=True
        )
        prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2]
        text_embeds = prompt_embeds_2_out["text_embeds"]
        prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
        return prompt_embeds, text_embeds

    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, bs, dtype):
        """Concatenate the three size tuples and repeat the result `bs` times as an array of `dtype`."""
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = jnp.array([add_time_ids] * bs, dtype=dtype)
        return add_time_ids

    def _generate(
        self,
        prompt_ids: jax.Array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.Array,
        num_inference_steps: int,
        height: int,
        width: int,
        guidance_scale: float,
        latents: Optional[jax.Array] = None,
        neg_prompt_ids: Optional[jax.Array] = None,
        return_latents=False,
    ):
        """
        Run the full denoising loop with classifier-free guidance and optionally decode the result.

        Args:
            prompt_ids: Token ids as produced by `prepare_inputs`.
            params: Parameter dictionary; `"unet"`, `"vae"`, `"scheduler"` and both text encoder entries are read.
            prng_seed: PRNG key used to sample the initial latents when `latents` is `None`.
            num_inference_steps: Number of denoising steps.
            height: Output height in pixels.
            width: Output width in pixels.
            guidance_scale: Classifier-free guidance weight.
            latents: Optional initial latents; must have shape
                `(batch, unet.config.in_channels, height // vae_scale_factor, width // vae_scale_factor)`.
            neg_prompt_ids: Optional negative prompt token ids; zero embeddings are used when omitted.
            return_latents: If true, return the final latents without decoding them.

        Returns:
            The final latents, or decoded images clipped to `[0, 1]` with the channel axis moved last.

        Raises:
            ValueError: If `height` or `width` is not divisible by 8, or `latents` has an unexpected shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # Encode input prompt
        prompt_embeds, pooled_embeds = self.get_embeddings(prompt_ids, params)

        # Get unconditional embeddings
        batch_size = prompt_embeds.shape[0]
        if neg_prompt_ids is None:
            neg_prompt_embeds = jnp.zeros_like(prompt_embeds)
            negative_pooled_embeds = jnp.zeros_like(pooled_embeds)
        else:
            neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params)

        add_time_ids = self._get_add_time_ids(
            (height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype
        )

        # Unconditional entries come first; `loop_body` splits the UNet output in the same order.
        prompt_embeds = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0)  # (2, 77, 2048)
        add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
        add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)

        # Ensure model output will be `float32` before going into the scheduler
        guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)

        # Create random latents
        latents_shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        # Prepare scheduler state
        scheduler_state = self.scheduler.set_timesteps(
            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
        )

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * scheduler_state.init_noise_sigma

        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

        # Denoising loop
        def loop_body(step, args):
            latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=added_cond_kwargs,
            ).sample
            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, scheduler_state

        if DEBUG:
            # run with python for loop
            for i in range(num_inference_steps):
                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
        else:
            latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

        if return_latents:
            return latents

        # Decode latents
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        # Map from [-1, 1] to [0, 1] and move the channel axis last
        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image


# Static argnums are pipe, num_inference_steps, height, width, return_latents. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial(
    jax.pmap,
    in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0, None),
    static_broadcasted_argnums=(0, 4, 5, 6, 10),
)
def _p_generate(
    pipe,
    prompt_ids,
    params,
    prng_seed,
    num_inference_steps,
    height,
    width,
    guidance_scale,
    latents,
    neg_prompt_ids,
    return_latents,
):
    """
    `jax.pmap`-wrapped entry point that forwards every argument to `pipe._generate`.

    Arguments at positions 0, 4, 5, 6 and 10 (`pipe`, `num_inference_steps`, `height`, `width`,
    `return_latents`) are static and broadcast; the remaining arguments are mapped over their first axis.
    """
    return pipe._generate(
        prompt_ids,
        params,
        prng_seed,
        num_inference_steps,
        height,
        width,
        guidance_scale,
        latents,
        neg_prompt_ids,
        return_latents,
    )


================================================
FILE: src/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py
================================================
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput, is_flax_available


@dataclass
class StableDiffusionXLPipelineOutput(BaseOutput):
    """
    Output class for Stable Diffusion XL pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]


# The Flax output class is only defined when flax can be imported.
if is_flax_available():
    import flax

    @flax.struct.dataclass
    class FlaxStableDiffusionXLPipelineOutput(BaseOutput):
        """
        Output class for Flax Stable Diffusion XL pipelines.

        Args:
            images (`np.ndarray`)
                Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline.
        """

        images: np.ndarray


================================================
FILE: src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, is_invisible_watermark_available, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import StableDiffusionXLPipelineOutput if is_invisible_watermark_available(): from .watermark import StableDiffusionXLWatermarker if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionXLPipeline >>> pipe = 
StableDiffusionXLPipeline.from_pretrained( ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> prompt = "a photo of an astronaut riding a horse on mars" >>> image = pipe(prompt).images[0] ``` """ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs, ): """ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. Args: scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. sigmas (`List[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." 
) scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps class StableDiffusionXLPipeline( DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, IPAdapterMixin, ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion XL uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([` CLIPTextModelWithProjection`]): Second frozen text-encoder. 
Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1-0`. add_watermarker (`bool`, *optional*): Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to watermark output images. If not defined, it will default to True if the package is installed, otherwise no watermarker will be used. 
# Order in which the sub-models are offloaded to CPU, written as an arrow-separated chain.
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"

# Components that may be absent from a loaded pipeline.
_optional_components = [
    "tokenizer",
    "tokenizer_2",
    "text_encoder",
    "text_encoder_2",
    "image_encoder",
    "feature_extractor",
]

# Names accepted in `callback_on_step_end_tensor_inputs` (validated by `check_inputs`).
_callback_tensor_inputs = [
    "latents",
    "prompt_embeds",
    "negative_prompt_embeds",
    "add_text_embeds",
    "add_time_ids",
    "negative_pooled_prompt_embeds",
    "negative_add_time_ids",
]

def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    image_encoder: CLIPVisionModelWithProjection = None,
    feature_extractor: CLIPImageProcessor = None,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
):
    """
    Register the pipeline components and build the helpers derived from them.

    The model, tokenizer, scheduler, image-encoder and feature-extractor arguments are handed to
    `register_modules`; `force_zeros_for_empty_prompt` is stored through `register_to_config`.

    Derived attributes set here:
        - `vae_scale_factor`: `2 ** (len(vae.config.block_out_channels) - 1)`.
        - `image_processor`: a `VaeImageProcessor` built with that scale factor.
        - `default_sample_size`: `unet.config.sample_size`.
        - `watermark`: a `StableDiffusionXLWatermarker` when watermarking is enabled, else `None`.
          When `add_watermarker` is `None`, `is_invisible_watermark_available()` decides.
    """
    super().__init__()

    components = dict(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=feature_extractor,
    )
    self.register_modules(**components)
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

    num_vae_levels = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_levels - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.default_sample_size = self.unet.config.sample_size

    # An explicit True/False wins; only `None` falls back to the availability probe.
    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()
    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None
def encode_prompt(
    self,
    prompt: str,
    prompt_2: Optional[str] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[str] = None,
    negative_prompt_2: Optional[str] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        device: (`torch.device`):
            torch device. Falls back to `self._execution_device` when not given.
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        A 4-tuple `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
        negative_pooled_prompt_embeds)`. The positive entries are repeated `num_images_per_prompt` times. The
        negative entries are computed/repeated only when `do_classifier_free_guidance` is truthy; otherwise they
        are returned exactly as they were passed in (possibly `None`).

    Raises:
        TypeError: if the computed-negative branch finds `prompt` and `negative_prompt` of different types.
        ValueError: if the computed-negative branch finds a `negative_prompt` whose length differs from the
            batch size.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if self.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(self.text_encoder_2, lora_scale)

    # A single string becomes a one-element batch; lists and `None` pass through unchanged.
    prompt = [prompt] if isinstance(prompt, str) else prompt

    # With no text prompt the batch size is taken from the first dimension of `prompt_embeds`.
    if prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders
    # When the first tokenizer / text encoder is missing, only the second one is used.
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # textual inversion: process multi-vector tokens if necessary
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        # NOTE(review): the loop variable rebinds `prompt`. After the loop, `prompt` is the last (possibly
        # converted) entry of `prompts`, and that is the value the type check further down compares against.
        # `zip` stops at the shortest input, so with a single tokenizer only `prompts[0]` is encoded.
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, tokenizer)

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            # Tokenize a second time without truncation, only to detect and report what was cut off.
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

            # We are only ALWAYS interested in the pooled output of the final text encoder
            # (`pooled_prompt_embeds` is overwritten on every iteration, so the last encoder's value survives.)
            pooled_prompt_embeds = prompt_embeds[0]
            if clip_skip is None:
                prompt_embeds = prompt_embeds.hidden_states[-2]
            else:
                # "2" because SDXL always indexes from the penultimate layer.
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

            prompt_embeds_list.append(prompt_embeds)

        # Join the per-encoder hidden states along the feature (last) dimension.
        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        # No negative prompt and the config asks for zeros: use all-zero tensors shaped like the positives.
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # normalize str to list
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt_2 = (
            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
        )

        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = [negative_prompt, negative_prompt_2]

        negative_prompt_embeds_list = []
        # NOTE(review): as above, the loop variable rebinds `negative_prompt`.
        for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

            # Pad/truncate the negative prompt to the sequence length of the positive embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            # NOTE(review): `clip_skip` is not consulted here; the negative branch always takes
            # `hidden_states[-2]`, unlike the positive branch above.
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    # Cast to the second text encoder's dtype when it exists, otherwise to the UNet's dtype.
    if self.text_encoder_2 is not None:
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    else:
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        if self.text_encoder_2 is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        # NOTE(review): this reshape uses `batch_size`, while the positive branch uses `bs_embed`.
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    # NOTE(review): the unscale calls run even when `lora_scale` is None (the scaling above is skipped in
    # that case); presumably `unscale_lora_layers` tolerates a `None` weight - confirm against its definition.
    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode an image with `self.image_encoder` and return a conditional/unconditional pair.

    Args:
        image: A `torch.Tensor`, or any other input accepted by `self.feature_extractor` (non-tensor inputs
            are converted through it with `return_tensors="pt"` and its `pixel_values` are used).
        device: Device the image tensor is moved to before encoding.
        num_images_per_prompt: Each output row is repeated this many times along dim 0.
        output_hidden_states: When truthy, return the second-to-last hidden state of the encoder for both
            the image and an all-zero image; otherwise return `image_embeds` and zeros of the same shape.

    Returns:
        A tuple `(conditional, unconditional)` of tensors.
    """
    # The image is cast to the dtype of the encoder's first parameter.
    encoder_dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = self.image_encoder(image).image_embeds
        cond_embeds = cond_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    def _penultimate_hidden_state(pixels):
        encoded = self.image_encoder(pixels, output_hidden_states=True)
        return encoded.hidden_states[-2].repeat_interleave(num_images_per_prompt, dim=0)

    cond_hidden = _penultimate_hidden_state(image)
    # The unconditional branch encodes an all-zero image rather than zeroing the embedding.
    uncond_hidden = _penultimate_hidden_state(torch.zeros_like(image))
    return cond_hidden, uncond_hidden
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build one image-embedding tensor per IP-Adapter, tiled for `num_images_per_prompt`.

    Args:
        ip_adapter_image: Image (or list of images) to encode; used only when `ip_adapter_image_embeds`
            is `None`. A non-list value is wrapped into a one-element list.
        ip_adapter_image_embeds: Optional list of precomputed embeddings. With classifier-free guidance
            each entry is split in two along dim 0: first half negative, second half positive.
        device: Device every returned tensor is moved to.
        num_images_per_prompt: Number of times each embedding is concatenated with itself along dim 0.
        do_classifier_free_guidance: When truthy, the tiled negative embedding is concatenated in front of
            the tiled positive embedding.

    Returns:
        A list of tensors, one per input image / precomputed embedding.

    Raises:
        ValueError: if images are given and their count differs from the number of
            `self.unet.encoder_hid_proj.image_projection_layers`.
    """
    positives = []
    negatives = []

    if ip_adapter_image_embeds is not None:
        for packed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_half, packed = packed.chunk(2)
                negatives.append(negative_half)
            positives.append(packed)
    else:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        projection_layers = self.unet.encoder_hid_proj.image_projection_layers
        for adapter_image, projection_layer in zip(ip_adapter_image, projection_layers):
            # Hidden states are requested for every projection layer that is not an `ImageProjection`.
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            cond, uncond = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positives.append(cond[None, :])
            if do_classifier_free_guidance:
                negatives.append(uncond[None, :])

    prepared = []
    for index, cond in enumerate(positives):
        tiled = torch.cat([cond] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negatives[index]] * num_images_per_prompt, dim=0)
            tiled = torch.cat([tiled_negative, tiled], dim=0)
        prepared.append(tiled.to(device=device))

    return prepared
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are forwarded only when
    `step` declares a parameter of that name.

    eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. eta corresponds
    to η in DDIM paper: https://arxiv.org/abs/2010.02502 and should be between [0, 1].

    Returns:
        A dict containing `"eta"` and/or `"generator"`, possibly empty.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    offered = (("eta", eta), ("generator", generator))
    return {name: value for name, value in offered if name in step_parameters}
def check_inputs(
    self,
    prompt,
    prompt_2,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the argument combination passed to the pipeline call; return `None` when everything is fine.

    The checks run in a fixed order and the first failing one raises:
        1. `height` and `width` are divisible by 8.
        2. `callback_steps`, if given, is a positive `int`.
        3. every name in `callback_on_step_end_tensor_inputs` is listed in `self._callback_tensor_inputs`.
        4. exactly one of text prompts / `prompt_embeds` is supplied, and text prompts are `str` or `list`.
        5. negative text prompts are not combined with `negative_prompt_embeds`.
        6. `prompt_embeds` and `negative_prompt_embeds`, when both given, have equal shapes.
        7. each embedding tensor comes with its pooled counterpart.
        8. `ip_adapter_image` and `ip_adapter_image_embeds` are not both given, and the latter is a list whose
           first element is a 3D or 4D tensor (only the first element is inspected).

    Raises:
        ValueError: on the first violated rule.
    """
    if any(side % 8 != 0 for side in (height, width)):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is not None and not (isinstance(callback_steps, int) and callback_steps > 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None and any(
        k not in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # --- positive prompt vs. precomputed embeddings -------------------------------------------------
    has_prompt = prompt is not None
    has_prompt_2 = prompt_2 is not None
    has_embeds = prompt_embeds is not None

    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if has_prompt_2 and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    if has_prompt_2 and not isinstance(prompt_2, (str, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    # --- negative prompt vs. precomputed negative embeddings ------------------------------------------
    has_negative_embeds = negative_prompt_embeds is not None

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    if negative_prompt_2 is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    # --- pooled embeddings must accompany their sequence embeddings -----------------------------------
    if has_embeds and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if has_negative_embeds and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )

    # --- IP-Adapter inputs ------------------------------------------------------------------------------
    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is None:
        return

    if not isinstance(ip_adapter_image_embeds, list):
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
        )
    if ip_adapter_image_embeds[0].ndim not in [3, 4]:
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
        )
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the initial latents, scaled by the scheduler's `init_noise_sigma`.

    Args:
        batch_size: Effective batch size; also the required length of `generator` when it is a list.
        num_channels_latents: Channel dimension of the sampled latents.
        height, width: Image size; each is integer-divided by `self.vae_scale_factor` to get the latent size.
        dtype: Dtype used when new latents are sampled (not applied to user-supplied `latents`).
        device: Device for sampled latents, or the device supplied `latents` are moved to.
        generator: Generator or list of generators forwarded to `randn_tensor`.
        latents: Optional pre-made latents; when given they are only moved to `device`, not resampled.

    Raises:
        ValueError: if `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(width) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        latents = latents.to(device)
    else:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return latents * self.scheduler.init_noise_sigma
def _get_add_time_ids(
    self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
):
    """
    Pack the size/crop micro-conditioning values into a `(1, N)` tensor after checking the UNet config.

    The three arguments are concatenated with `+` (so they must be sequences of the same type) and turned
    into a list. The embedding width the UNet would receive is
    `unet.config.addition_time_embed_dim * N + text_encoder_projection_dim` and has to equal
    `unet.add_embedding.linear_1.in_features`.

    NOTE(review): `text_encoder_projection_dim` defaults to `None`, but the sum above needs a number, so
    callers have to pass it.

    Returns:
        `torch.Tensor` of shape `(1, N)` and dtype `dtype`.

    Raises:
        ValueError: if the computed embedding width does not match the UNet's expected width.
    """
    micro_conditioning = list(original_size + crops_coords_top_left + target_size)

    per_value_dim = self.unet.config.addition_time_embed_dim
    passed_add_embed_dim = per_value_dim * len(micro_conditioning) + text_encoder_projection_dim
    expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

    if expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    return torch.tensor([micro_conditioning], dtype=dtype)

def upcast_vae(self):
    """
    Move the VAE to float32, keeping selected sub-modules in the previous dtype when that is safe.

    The whole VAE is first converted to `torch.float32`. If the first attention processor of the decoder
    mid-block is one of `AttnProcessor2_0`, `XFormersAttnProcessor` or `FusedAttnProcessor2_0`, then
    `post_quant_conv`, `decoder.conv_in` and `decoder.mid_block` are converted back to the dtype the VAE
    had on entry.
    """
    original_dtype = self.vae.dtype
    self.vae.to(dtype=torch.float32)

    memory_efficient_processors = (
        AttnProcessor2_0,
        XFormersAttnProcessor,
        FusedAttnProcessor2_0,
    )
    mid_block_processor = self.vae.decoder.mid_block.attentions[0].processor
    if not isinstance(mid_block_processor, memory_efficient_processors):
        return

    # if xformers or torch_2_0 is used attention block does not need
    # to be in float32 which can save lots of memory
    for module in (self.vae.post_quant_conv, self.vae.decoder.conv_in, self.vae.decoder.mid_block):
        module.to(original_dtype)
# Adapted from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
# (argument validation and device placement differ from the upstream copy)
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
            Must be one-dimensional.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, on the same device as `w`.

    Raises:
        ValueError: if `w` is not a 1D tensor.
    """
    # Explicit check instead of `assert`, so the validation survives `python -O`.
    if w.ndim != 1:
        raise ValueError(f"`w` has to be a 1D tensor of guidance scales but has shape {tuple(w.shape)}.")

    w = w * 1000.0

    half_dim = embedding_dim // 2
    # NOTE(review): `embedding_dim` of 2 or 3 gives `half_dim - 1 == 0`, i.e. a division by zero below; this
    # is not guarded here.
    log_step = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # Build the frequency table on `w`'s device so the product below also works for non-CPU inputs.
    frequencies = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -log_step)
    emb = w.to(dtype)[:, None] * frequencies[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    # Internal sanity check on the shape computed above.
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb

@property
def guidance_scale(self):
    """Return the stored `_guidance_scale` value."""
    return self._guidance_scale

@property
def guidance_rescale(self):
    """Return the stored `_guidance_rescale` value."""
    return self._guidance_rescale

@property
def clip_skip(self):
    """Return the stored `_clip_skip` value."""
    return self._clip_skip

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """
    Tell whether classifier-free guidance is active.

    Truthy only when the stored guidance scale is greater than 1 and `unet.config.time_cond_proj_dim`
    is `None`.
    """
    guidance_active = self._guidance_scale > 1
    return guidance_active and self.unet.config.time_cond_proj_dim is None

@property
def cross_attention_kwargs(self):
    """Return the stored `_cross_attention_kwargs` value."""
    return self._cross_attention_kwargs

@property
def denoising_end(self):
    """Return the stored `_denoising_end` value."""
    return self._denoising_end

@property
def num_timesteps(self):
    """Return the stored `_num_timesteps` value."""
    return self._num_timesteps

@property
def interrupt(self):
    """Return the stored `_interrupt` value."""
    return self._interrupt
MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. 
If not defined, the default behavior when `num_inference_steps` is passed will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. 
If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). 
Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a specific image resolution. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a target image resolution. It should be as same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 1. Check inputs. 
Raise error if not correct self.check_inputs( prompt, prompt_2, height, width, callback_steps, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._guidance_rescale = guidance_rescale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. Encode input prompt lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=lora_scale, clip_skip=self.clip_skip, ) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 8. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) # 8.1 Apply denoising_end if ( self.denoising_end is not None and isinstance(self.denoising_end, float) and self.denoising_end > 0 and self.denoising_end < 1 ): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # 9. Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - 
noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if XLA_AVAILABLE: xm.mark_step() if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = 
latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) elif latents.dtype != self.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None if has_latents_mean and has_latents_std: latents_mean = ( torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents_std = ( torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean else: latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents if not output_type == "latent": # apply watermark if available if self.watermark is not None: image = self.watermark.apply_watermark(image) image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import PIL.Image import torch from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, is_invisible_watermark_available, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import StableDiffusionXLPipelineOutput if is_invisible_watermark_available(): from .watermark import StableDiffusionXLWatermarker if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionXLImg2ImgPipeline >>> from diffusers.utils import load_image >>> 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend the guided noise prediction with a std-matched version of itself.

    Based on findings of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    Args:
        noise_cfg: Noise prediction after classifier-free guidance was applied.
        noise_pred_text: Text-conditioned noise prediction; its standard deviation is the rescaling target.
        guidance_rescale: Interpolation weight. `0.0` returns `noise_cfg` unchanged in value, `1.0` returns the
            fully rescaled prediction.

    Returns:
        The interpolated noise prediction.
    """
    # The standard deviation is reduced over every dimension except the first one, and kept
    # broadcastable against the inputs.
    text_reduce_dims = list(range(1, noise_pred_text.ndim))
    cfg_reduce_dims = list(range(1, noise_cfg.ndim))
    text_std = noise_pred_text.std(dim=text_reduce_dims, keepdim=True)
    cfg_std = noise_cfg.std(dim=cfg_reduce_dims, keepdim=True)

    # rescale the results from guidance (fixes overexposure)
    rescaled = noise_cfg * (text_std / cfg_std)

    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Pull latents out of an encoder output object.

    Args:
        encoder_output: Object exposing either a `latent_dist` attribute (with `sample(generator)` / `mode()`) or a
            `latents` attribute.
        generator: Forwarded to `latent_dist.sample` when `sample_mode` is `"sample"`.
        sample_mode: `"sample"` draws from `latent_dist`, `"argmax"` takes `latent_dist.mode()`. Any other value
            skips `latent_dist` and falls back to the `latents` attribute.

    Returns:
        Whatever the selected accessor returns.

    Raises:
        AttributeError: If none of the accessors above is applicable.
    """
    if hasattr(encoder_output, "latent_dist"):
        distribution = encoder_output.latent_dist
        if sample_mode == "sample":
            return distribution.sample(generator)
        if sample_mode == "argmax":
            return distribution.mode()

    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# NOTE(review): this function carries a `Copied from` marker; the same change presumably has to be made in the
# referenced source function to keep the copies identical — confirm against the repository's copy-check tooling.
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
            declare the custom argument that was passed.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive its own spacing from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: `Signature.parameters` is already a mapping keyed by parameter name, so membership is
    # tested on it directly instead of materialising a temporary set of its keys.
    accepted_params = inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # With a custom schedule the step count is whatever the scheduler ended up with, not the caller's value.
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
Stable Diffusion XL uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([` CLIPTextModelWithProjection`]): Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config of `stabilityai/stable-diffusion-xl-refiner-1-0`. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1-0`. add_watermarker (`bool`, *optional*): Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to watermark output images. If not defined, it will default to True if the package is installed, otherwise no watermarker will be used. 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" _optional_components = [ "tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2", "image_encoder", "feature_extractor", ] _callback_tensor_inputs = [ "latents", "prompt_embeds", "negative_prompt_embeds", "add_text_embeds", "add_time_ids", "negative_pooled_prompt_embeds", "add_neg_time_ids", ] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, text_encoder_2: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, image_encoder: CLIPVisionModelWithProjection = None, feature_extractor: CLIPImageProcessor = None, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, unet=unet, image_encoder=image_encoder, feature_extractor=feature_extractor, scheduler=scheduler, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() if add_watermarker: self.watermark = StableDiffusionXLWatermarker() else: self.watermark = None # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( self, prompt: str, prompt_2: Optional[str] = None, device: Optional[torch.device] = None, num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, negative_prompt: Optional[str] = None, negative_prompt_2: Optional[str] = 
None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, 
prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. 
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = 
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments accepted by `self.scheduler.step`.

    Schedulers do not share one `step` signature, so each optional argument is
    forwarded only when the scheduler's `step` declares a parameter with that name.

    Args:
        generator: value forwarded as `generator` when `step` accepts it.
        eta: value forwarded as `eta` when `step` accepts it. Per the DDIM paper
            (https://arxiv.org/abs/2010.02502) it corresponds to η and should lie in [0, 1].

    Returns:
        `dict`: contains `"eta"` and/or `"generator"` (in that order) for the names
        that `self.scheduler.step` accepts; empty when it accepts neither.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
) elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) elif negative_prompt_2 is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
    """Pick the tail of `self.scheduler.timesteps` that the denoising loop should run.

    Args:
        num_inference_steps (`int`): number of inference steps requested.
        strength (`float`): fraction of the schedule to run. Only used when
            `denoising_start` is `None`.
        device: not used by this method; kept so the signature stays unchanged.
        denoising_start (`float`, *optional*): when given, `strength` is ignored and the
            kept timesteps are those strictly below
            `round(num_train_timesteps - denoising_start * num_train_timesteps)`.

    Returns:
        `tuple`: `(timesteps, num_inference_steps)` where `timesteps` is the selected
        slice and `num_inference_steps` the number of steps that slice represents.
    """
    scheduler = self.scheduler

    if denoising_start is None:
        # Strength-driven path: skip the first part of the schedule.
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = scheduler.timesteps[t_start * scheduler.order :]
        return timesteps, num_inference_steps - t_start

    # Strength is irrelevant if we directly request a timestep to start at;
    # that is, strength is determined by the denoising_start instead.
    timesteps = scheduler.timesteps
    num_train_timesteps = scheduler.config.num_train_timesteps
    discrete_timestep_cutoff = int(round(num_train_timesteps - (denoising_start * num_train_timesteps)))

    num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
    if scheduler.order == 2 and num_inference_steps % 2 == 0:
        # if the scheduler is a 2nd order scheduler we might have to do +1
        # because `num_inference_steps` might be even given that every timestep
        # (except the highest one) is duplicated. If `num_inference_steps` is even it would
        # mean that we cut the timesteps in the middle of the denoising step
        # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
        # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
        num_inference_steps = num_inference_steps + 1

    # because t_n+1 >= t_n, we slice the timesteps starting from the end.
    # An explicit start index is used instead of `timesteps[-num_inference_steps:]`:
    # with zero remaining steps, `[-0:]` would return the WHOLE schedule rather than
    # an empty one. The index is clamped at 0 so a count larger than the schedule
    # still yields every timestep.
    first_kept = max(len(timesteps) - num_inference_steps, 0)
    timesteps = timesteps[first_kept:]
    return timesteps, num_inference_steps
) elif isinstance(generator, list): if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " ) init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = retrieve_latents(self.vae.encode(image), generator=generator) if self.vae.config.force_upcast: self.vae.to(dtype) init_latents = init_latents.to(dtype) if latents_mean is not None and latents_std is not None: latents_mean = latents_mean.to(device=device, dtype=dtype) latents_std = latents_std.to(device=device, dtype=dtype) init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std else: init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode `image` with `self.image_encoder` and build a matching unconditional output.

    Args:
        image: a `torch.Tensor`, or any other input, which is first passed through
            `self.feature_extractor(..., return_tensors="pt")` to obtain `pixel_values`.
        device: device the pixel tensor is moved to before encoding.
        num_images_per_prompt (`int`): each row of the result is repeated this many
            times along dim 0 (`repeat_interleave`).
        output_hidden_states (`bool`, *optional*): when truthy, return the encoder's
            second-to-last hidden state instead of `image_embeds`.

    Returns:
        `tuple`: `(conditional, unconditional)`. With hidden states, the unconditional
        entry is the encoder output for an all-zero image; otherwise it is a zero
        tensor shaped like the conditional embeddings.
    """
    encoder = self.image_encoder
    encoder_dtype = next(encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    pixels = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = encoder(pixels).image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    def penultimate_states(batch):
        # Second-to-last hidden state, repeated once per requested image.
        states = encoder(batch, output_hidden_states=True).hidden_states[-2]
        return states.repeat_interleave(num_images_per_prompt, dim=0)

    cond_states = penultimate_states(pixels)
    uncond_states = penultimate_states(torch.zeros_like(pixels))
    return cond_states, uncond_states
def _get_add_time_ids(
    self,
    original_size,
    crops_coords_top_left,
    target_size,
    aesthetic_score,
    negative_aesthetic_score,
    negative_original_size,
    negative_crops_coords_top_left,
    negative_target_size,
    dtype,
    text_encoder_projection_dim=None,
):
    """Build the positive and negative additional time-id tensors.

    When `self.config.requires_aesthetics_score` is set, each id vector is
    `size + crop + (aesthetic score,)`; otherwise it is `size + crop + target_size`.
    The negative vector is built entirely from the `negative_*` arguments.

    Args:
        original_size, crops_coords_top_left, target_size: tuples concatenated into
            the positive ids.
        aesthetic_score, negative_aesthetic_score: scalars appended when
            `requires_aesthetics_score` is enabled.
        negative_original_size, negative_crops_coords_top_left, negative_target_size:
            tuples concatenated into the negative ids.
        dtype: dtype of the returned tensors.
        text_encoder_projection_dim (`int`): added to the time-embedding width when
            checking against `self.unet.add_embedding.linear_1.in_features`.

    Returns:
        `tuple`: `(add_time_ids, add_neg_time_ids)`, each a tensor with one row.

    Raises:
        ValueError: if the embedding width implied by the ids does not match the
            width expected by the UNet's `add_embedding` layer.
    """
    if self.config.requires_aesthetics_score:
        add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
        add_neg_time_ids = list(
            negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
        )
    else:
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        # Use the negative crop coordinates here, as the aesthetics branch does;
        # previously the positive `crops_coords_top_left` was reused and the
        # `negative_crops_coords_top_left` argument was silently ignored.
        add_neg_time_ids = list(negative_original_size + negative_crops_coords_top_left + negative_target_size)

    time_embed_dim = self.unet.config.addition_time_embed_dim
    passed_add_embed_dim = time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
    expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

    # A gap of exactly one time embedding means the aesthetic-score setting is
    # the likely culprit, so those two cases get a targeted message.
    if (
        expected_add_embed_dim > passed_add_embed_dim
        and (expected_add_embed_dim - passed_add_embed_dim) == time_embed_dim
    ):
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
        )
    elif (
        expected_add_embed_dim < passed_add_embed_dim
        and (passed_add_embed_dim - expected_add_embed_dim) == time_embed_dim
    ):
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
        )
    elif expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)

    return add_time_ids, add_neg_time_ids
""" assert len(w.shape) == 1 w = w * 1000.0 half_dim = embedding_dim // 2 emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) emb = w.to(dtype)[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb @property def guidance_scale(self): return self._guidance_scale @property def guidance_rescale(self): return self._guidance_rescale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None @property def cross_attention_kwargs(self): return self._cross_attention_kwargs @property def denoising_end(self): return self._denoising_end @property def denoising_start(self): return self._denoising_start @property def num_timesteps(self): return self._num_timesteps @property def interrupt(self): return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, strength: float = 0.3, num_inference_steps: int = 50, timesteps: List[int] = None, sigmas: List[float] = None, denoising_start: Optional[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = 
None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): The image(s) to modify with the pipeline. strength (`float`, *optional*, defaults to 0.3): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. 
When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of `denoising_start` being declared as an integer, the value of `strength` will be ignored. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. denoising_start (`float`, *optional*): When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise (ca. 
final 20% of timesteps still needed) and should be denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. 
latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. 
Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a specific image resolution. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a target image resolution. It should be as same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. aesthetic_score (`float`, *optional*, defaults to 6.0): Used to simulate an aesthetic score of the generated image by influencing the positive text condition. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
negative_aesthetic_score (`float`, *optional*, defaults to 2.5): Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to simulate an aesthetic score of the generated image by influencing the negative text condition. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
""" callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, prompt_2, strength, num_inference_steps, callback_steps, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._guidance_rescale = guidance_rescale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._denoising_start = denoising_start self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. 
Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) # 4. Preprocess image image = self.image_processor.preprocess(image) # 5. Prepare timesteps def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, device, denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) add_noise = True if self.denoising_start is None else False # 6. Prepare latent variables if latents is None: latents = self.prepare_latents( image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator, add_noise, ) # 7. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) height, width = latents.shape[-2:] height = height * self.vae_scale_factor width = width * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 8. 
Prepare added time ids & embeddings if negative_original_size is None: negative_original_size = original_size if negative_target_size is None: negative_target_size = target_size add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 9. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) # 9.1 Apply denoising_end if ( self.denoising_end is not None and self.denoising_start is not None and denoising_value_valid(self.denoising_end) and denoising_value_valid(self.denoising_start) and self.denoising_start >= self.denoising_end ): raise ValueError( f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + f" {self.denoising_end} when using type float." ) elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # 9.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, 
timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) 
callback(step_idx, t, latents) if XLA_AVAILABLE: xm.mark_step() if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) elif latents.dtype != self.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None if has_latents_mean and has_latents_std: latents_mean = ( torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents_std = ( torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean else: latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents # apply watermark if available if self.watermark is not None: image = self.watermark.apply_watermark(image) image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py ================================================ # Copyright 2024 The HuggingFace Team. 
All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, is_invisible_watermark_available, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import StableDiffusionXLPipelineOutput if is_invisible_watermark_available(): from .watermark import StableDiffusionXLWatermarker if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on Section 3.4 of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg:
            Noise prediction to rescale.
        noise_pred_text:
            Text-conditioned noise prediction; its standard deviation is the target of the rescaling.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Blend weight between the rescaled prediction (weight `guidance_rescale`) and the unmodified `noise_cfg`
            (weight `1 - guidance_rescale`).

    Returns:
        The blended prediction, computed as `guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg`.
    """
    # Standard deviations are taken over every dimension except the first one.
    text_dims = list(range(1, noise_pred_text.ndim))
    cfg_dims = list(range(1, noise_cfg.ndim))
    sigma_text = noise_pred_text.std(dim=text_dims, keepdim=True)
    sigma_cfg = noise_cfg.std(dim=cfg_dims, keepdim=True)

    # Bring the spread of the guided prediction back to that of the text branch (fixes overexposure).
    rescaled = noise_cfg * (sigma_text / sigma_cfg)

    # Mix with the original guided result by `guidance_rescale` to avoid "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


def mask_pil_to_torch(mask, height, width):
    """
    Convert a mask given as PIL image(s) or NumPy array(s) into a torch tensor.

    A single PIL image or array is treated as a one-element list. PIL images are resized to `(width, height)` with
    Lanczos resampling, converted to mode "L", stacked into a `(N, 1, H, W)` float32 array and divided by 255. NumPy
    arrays only receive two leading axes before being concatenated along the first one; they are neither resized nor
    rescaled, so `height` and `width` are ignored for them.

    Args:
        mask: A PIL image, a NumPy array, or a list of either.
        height (`int`): Target height used when resizing PIL masks.
        width (`int`): Target width used when resizing PIL masks.

    Returns:
        `torch.Tensor`: The result of `torch.from_numpy` on the assembled array.
    """
    if isinstance(mask, (PIL.Image.Image, np.ndarray)):
        mask = [mask]

    if isinstance(mask, list):
        first = mask[0]
        if isinstance(first, PIL.Image.Image):
            resized = [img.resize((width, height), resample=PIL.Image.LANCZOS) for img in mask]
            planes = [np.array(img.convert("L"))[None, None, :] for img in resized]
            mask = np.concatenate(planes, axis=0).astype(np.float32) / 255.0
        elif isinstance(first, np.ndarray):
            mask = np.concatenate([arr[None, None, :] for arr in mask], axis=0)

    return torch.from_numpy(mask)


# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Pull latents out of an encoder output object.

    Args:
        encoder_output: Object exposing either a `latent_dist` attribute or a `latents` attribute.
        generator: Passed to `latent_dist.sample` when `sample_mode` is `"sample"`.
        sample_mode (`str`): `"sample"` draws from `latent_dist`; `"argmax"` returns `latent_dist.mode()`.

    Returns:
        The sampled latents, the distribution mode, or the `latents` attribute, in that order of preference.

    Raises:
        AttributeError: If none of the lookups above applies.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()

    # Reached when there is no `latent_dist`, or when `sample_mode` is neither "sample" nor "argmax".
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not accept
            the custom argument that was passed.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler build its own schedule; the step count is returned as passed in.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: make sure this scheduler's `set_timesteps` accepts the requested keyword.
    accepted = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # With a custom schedule the step count is whatever the scheduler ended up with.
    schedule = scheduler.timesteps
    return schedule, len(schedule)
Stable Diffusion XL uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([`CLIPTextModelWithProjection`]): Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. requires_aesthetics_score (`bool`, *optional*, defaults to `False`): Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config of `stabilityai/stable-diffusion-xl-refiner-1.0`. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`): Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1.0`. add_watermarker (`bool`, *optional*): Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to watermark output images. If not defined, it will default to True if the package is installed, otherwise no watermarker will be used.
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" _optional_components = [ "tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2", "image_encoder", "feature_extractor", ] _callback_tensor_inputs = [ "latents", "prompt_embeds", "negative_prompt_embeds", "add_text_embeds", "add_time_ids", "negative_pooled_prompt_embeds", "add_neg_time_ids", "mask", "masked_image_latents", ] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, text_encoder_2: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, image_encoder: CLIPVisionModelWithProjection = None, feature_extractor: CLIPImageProcessor = None, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, unet=unet, image_encoder=image_encoder, feature_extractor=feature_extractor, scheduler=scheduler, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True ) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() if add_watermarker: self.watermark = StableDiffusionXLWatermarker() else: self.watermark = None # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode `image` with `self.image_encoder` and build the matching unconditional encoding.

    Args:
        image: A `torch.Tensor`, or any input accepted by `self.feature_extractor` (non-tensors are converted
            through its `pixel_values` output).
        device: Device the image is moved to before encoding.
        num_images_per_prompt (`int`): Repeat count applied along dim 0 with `repeat_interleave`.
        output_hidden_states (`bool`, *optional*): When truthy, return the second-to-last hidden state instead of
            the pooled `image_embeds`.

    Returns:
        A `(conditional, unconditional)` pair. With hidden states, the unconditional part is the encoder output
        for an all-zero image; otherwise it is a zero tensor shaped like the conditional embeddings.
    """
    # Match the dtype of the image encoder's parameters.
    encoder_dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    image = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = self.image_encoder(image).image_embeds
        cond_embeds = cond_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    cond_hidden = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
    cond_hidden = cond_hidden.repeat_interleave(num_images_per_prompt, dim=0)

    blank_image = torch.zeros_like(image)
    uncond_hidden = self.image_encoder(blank_image, output_hidden_states=True).hidden_states[-2]
    uncond_hidden = uncond_hidden.repeat_interleave(num_images_per_prompt, dim=0)
    return cond_hidden, uncond_hidden
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """
    Build one image-embedding tensor per IP Adapter, ready to be passed to the UNet.

    Args:
        ip_adapter_image: Image or list of images, one per IP Adapter. Only used when `ip_adapter_image_embeds`
            is `None`; each one is encoded with `self.encode_image`.
        ip_adapter_image_embeds: Optional list of precomputed embeddings. With classifier-free guidance each entry
            is split in two with `chunk(2)` into a negative half followed by a positive half.
        device: Device the resulting tensors are moved to.
        num_images_per_prompt (`int`): Number of copies concatenated along dim 0.
        do_classifier_free_guidance (`bool`): When true, the negative embeddings are placed in front of the
            positive ones in every returned tensor.

    Returns:
        `list`: One tensor per adapter.

    Raises:
        ValueError: If images are given and their count differs from the number of image projection layers.
    """
    positives = []
    negatives = []

    if ip_adapter_image_embeds is not None:
        for packed in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                negative_half, packed = packed.chunk(2)
                negatives.append(negative_half)
            positives.append(packed)
    else:
        images = ip_adapter_image if isinstance(ip_adapter_image, list) else [ip_adapter_image]
        projection_layers = self.unet.encoder_hid_proj.image_projection_layers

        if len(images) != len(projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(images)} images and {len(projection_layers)} IP Adapters."
            )

        for adapter_image, projection_layer in zip(images, projection_layers):
            # Every projection type other than ImageProjection is fed hidden states instead of pooled embeds.
            wants_hidden_states = not isinstance(projection_layer, ImageProjection)
            cond, uncond = self.encode_image(adapter_image, device, 1, wants_hidden_states)

            positives.append(cond[None, :])
            if do_classifier_free_guidance:
                negatives.append(uncond[None, :])

    prepared = []
    for idx, cond in enumerate(positives):
        tiled = torch.cat([cond] * num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            tiled_negative = torch.cat([negatives[idx]] * num_images_per_prompt, dim=0)
            tiled = torch.cat([tiled_negative, tiled], dim=0)
        prepared.append(tiled.to(device=device))

    return prepared
encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. 
lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = 
text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" 
{type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that ``self.scheduler.step`` accepts.

    Not all schedulers share the same ``step`` signature, so ``eta`` and
    ``generator`` are forwarded only when ``step`` declares a parameter with
    that name.

    Args:
        generator: Value stored under ``"generator"`` when ``step`` accepts it.
        eta: Value stored under ``"eta"`` when ``step`` accepts it. eta (η) is
            only used with the DDIMScheduler and corresponds to η in the DDIM
            paper (https://arxiv.org/abs/2010.02502); it should be in [0, 1].

    Returns:
        dict: The accepted subset of ``{"eta": eta, "generator": generator}``.
    """
    # Inspect the signature once and filter both candidates against it.
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_kwargs = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_kwargs if name in step_parameters}
def check_inputs(
    self,
    prompt,
    prompt_2,
    image,
    mask_image,
    height,
    width,
    strength,
    callback_steps,
    output_type,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    padding_mask_crop=None,
):
    """Validate the pipeline call arguments.

    The checks run in a fixed order and the first violated rule raises
    ``ValueError``; nothing is returned when every check passes.

    Rules enforced:
        * ``strength`` lies in [0, 1]; ``height`` and ``width`` are multiples of 8.
        * ``callback_steps``, when given, is a positive ``int``.
        * every name in ``callback_on_step_end_tensor_inputs`` is listed in
          ``self._callback_tensor_inputs``.
        * exactly one of ``prompt`` / ``prompt_embeds`` is given (``prompt_2``
          may not be combined with ``prompt_embeds``); prompts are ``str`` or ``list``.
        * negative prompts are not combined with ``negative_prompt_embeds``.
        * ``prompt_embeds`` and ``negative_prompt_embeds`` share a shape when both are given.
        * with ``padding_mask_crop`` set, ``image`` and ``mask_image`` are PIL
          images and ``output_type`` is ``"pil"``.
        * ``ip_adapter_image`` and ``ip_adapter_image_embeds`` are mutually
          exclusive; the embeds must be a list whose first item is 3D or 4D.

    Raises:
        ValueError: On the first rule that is violated.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    if height % 8 or width % 8:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is not None:
        is_positive_int = isinstance(callback_steps, int) and callback_steps > 0
        if not is_positive_int:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    if callback_on_step_end_tensor_inputs is not None:
        unsupported = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unsupported:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unsupported}"
            )

    # Prompt text and precomputed prompt embeddings are mutually exclusive.
    if prompt_embeds is not None:
        if prompt is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        if prompt_2 is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
    else:
        if prompt is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        if prompt_2 is not None and not isinstance(prompt_2, (str, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    # Same exclusivity rule for the negative side.
    if negative_prompt_embeds is not None:
        if negative_prompt is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        if negative_prompt_2 is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

    both_embeds_given = prompt_embeds is not None and negative_prompt_embeds is not None
    if both_embeds_given and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    if padding_mask_crop is not None:
        if not isinstance(image, PIL.Image.Image):
            raise ValueError(
                f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
            )
        if not isinstance(mask_image, PIL.Image.Image):
            raise ValueError(
                f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                f" {type(mask_image)}."
            )
        if output_type != "pil":
            raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")

    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        # Only the first entry's rank is inspected.
        first_ndim = ip_adapter_image_embeds[0].ndim
        if first_ndim not in (3, 4):
            raise ValueError(f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {first_ndim}D")
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
    image=None,
    timestep=None,
    is_strength_max=True,
    add_noise=True,
    return_noise=False,
    return_image_latents=False,
):
    """Build the starting latents for denoising, optionally with the noise and image latents.

    Args:
        batch_size: Leading dimension of the latent shape; image latents are
            repeated up to this size.
        num_channels_latents: Channel dimension of the latent shape.
        height: Pixel height; divided by ``self.vae_scale_factor`` for the latent shape.
        width: Pixel width; divided by ``self.vae_scale_factor`` for the latent shape.
        dtype: dtype of sampled noise and of the image moved to ``device``.
        device: Target device.
        generator: Generator, or list of generators whose length must equal ``batch_size``.
        latents: Optional pre-made latents. With ``add_noise=True`` they are
            used as the noise and scaled by ``scheduler.init_noise_sigma``.
        image: Optional image tensor. When ``image.shape[1] == 4`` it is used
            directly as image latents; otherwise it is encoded with
            ``self._encode_vae_image`` when image latents are needed.
        timestep: Passed to ``scheduler.add_noise`` when ``is_strength_max`` is False.
        is_strength_max: When True (and ``latents`` is None) the latents are pure
            noise scaled by ``scheduler.init_noise_sigma``; otherwise noise is
            added to the image latents.
        add_noise: When False the returned latents are the image latents.
        return_noise: Append the noise tensor to the returned tuple.
        return_image_latents: Append the image latents to the returned tuple.

    Returns:
        tuple: ``(latents,)`` followed by ``noise`` and/or ``image_latents``
        when requested, in that order.

    Raises:
        ValueError: If the generator list length mismatches ``batch_size``, if
            ``image``/``timestep`` is missing while ``is_strength_max`` is False,
            or if ``image`` is missing although image latents are required.
    """
    shape = (
        batch_size,
        num_channels_latents,
        int(height) // self.vae_scale_factor,
        int(width) // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if (image is None or timestep is None) and not is_strength_max:
        raise ValueError(
            "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise. "
            "However, either the image or the noise timestep has not been provided."
        )

    # Image latents are consumed when they are returned, when they become the
    # output latents (add_noise=False), or when noise is blended onto them.
    needs_image_latents = return_image_latents or not add_noise or (latents is None and not is_strength_max)

    image_latents = None
    if image is None:
        # `image` is optional in the signature; fail clearly instead of
        # dereferencing None below.
        if needs_image_latents:
            raise ValueError("`image` has to be provided when `return_image_latents=True` or `add_noise=False`.")
    elif image.shape[1] == 4:
        # A 4-channel input is taken to already be in latent space.
        image_latents = image.to(device=device, dtype=dtype)
        image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
    elif needs_image_latents:
        image = image.to(device=device, dtype=dtype)
        image_latents = self._encode_vae_image(image=image, generator=generator)
        image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)

    if latents is None and add_noise:
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        # if strength is 1. then initialise the latents to noise, else initial to image + noise
        latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
        # if pure noise then scale the initial latents by the Scheduler's init sigma
        latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
    elif add_noise:
        noise = latents.to(device)
        latents = noise * self.scheduler.init_noise_sigma
    else:
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = image_latents.to(device)

    outputs = (latents,)

    if return_noise:
        outputs += (noise,)

    if return_image_latents:
        outputs += (image_latents,)

    return outputs
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """Encode ``image`` with ``self.vae`` and return latents scaled by the VAE scaling factor.

    When ``self.vae.config.force_upcast`` is set, the image and the VAE are
    switched to float32 for the encode, and afterwards the VAE is moved to the
    image's original dtype. The returned latents are cast to that dtype too.

    Args:
        image: Image tensor handed to ``self.vae.encode``.
        generator: A generator, or a list indexed per sample; with a list each
            sample is encoded separately and the results are concatenated on dim 0.

    Returns:
        ``self.vae.config.scaling_factor`` times the retrieved latents.
    """
    input_dtype = image.dtype
    upcast = self.vae.config.force_upcast
    if upcast:
        image = image.float()
        self.vae.to(dtype=torch.float32)

    if isinstance(generator, list):
        # One encode per sample so that each sample uses its own generator.
        per_sample = []
        for idx in range(image.shape[0]):
            encoded = self.vae.encode(image[idx : idx + 1])
            per_sample.append(retrieve_latents(encoded, generator=generator[idx]))
        image_latents = torch.cat(per_sample, dim=0)
    else:
        image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

    if upcast:
        self.vae.to(input_dtype)

    image_latents = image_latents.to(input_dtype)
    return self.vae.config.scaling_factor * image_latents
def prepare_mask_latents(
    self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
    """Resize and batch the mask, and produce matching latents for the masked image.

    Args:
        mask: Mask tensor; interpolated to ``(height, width) // self.vae_scale_factor``.
        masked_image: Optional masked image. A tensor with ``shape[1] == 4`` is
            used as latents as-is; otherwise it is encoded with
            ``self._encode_vae_image``. ``None`` yields ``None`` latents.
        batch_size: Target batch size; inputs with fewer items are repeated.
        height: Pixel height.
        width: Pixel width.
        dtype: dtype of the returned tensors.
        device: Device of the returned tensors.
        generator: Forwarded to ``self._encode_vae_image``.
        do_classifier_free_guidance: When True both outputs are duplicated along dim 0.

    Returns:
        tuple: ``(mask, masked_image_latents)``.

    Raises:
        ValueError: If ``batch_size`` is not a multiple of the number of masks
            or of the number of masked images.
    """
    # resize the mask to latents shape as we concatenate the mask to the latents
    # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
    # and half precision
    latent_size = (height // self.vae_scale_factor, width // self.vae_scale_factor)
    mask = torch.nn.functional.interpolate(mask, size=latent_size)
    mask = mask.to(device=device, dtype=dtype)

    # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
    num_masks = mask.shape[0]
    if num_masks < batch_size:
        if batch_size % num_masks != 0:
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {batch_size}, but {num_masks} masks were passed. Make sure the number"
                " of masks that you pass is divisible by the total requested batch size."
            )
        mask = mask.repeat(batch_size // num_masks, 1, 1, 1)

    if do_classifier_free_guidance:
        mask = torch.cat([mask] * 2)

    if masked_image is None:
        return mask, None

    if masked_image.shape[1] == 4:
        masked_image_latents = masked_image
    else:
        masked_image = masked_image.to(device=device, dtype=dtype)
        masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

    num_images = masked_image_latents.shape[0]
    if num_images < batch_size:
        if batch_size % num_images != 0:
            raise ValueError(
                "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                f" to a total batch size of {batch_size}, but {num_images} images were passed."
                " Make sure the number of images that you pass is divisible by the total requested batch size."
            )
        masked_image_latents = masked_image_latents.repeat(batch_size // num_images, 1, 1, 1)

    if do_classifier_free_guidance:
        masked_image_latents = torch.cat([masked_image_latents] * 2)

    # aligning device to prevent device errors when concating it with the latent model input
    return mask, masked_image_latents.to(device=device, dtype=dtype)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
    """Select the tail of ``self.scheduler.timesteps`` that will actually be run.

    Args:
        num_inference_steps: Number of scheduled inference steps.
        strength: Fraction of the schedule to run; only used when
            ``denoising_start`` is None.
        device: Not used by this method.
        denoising_start: Optional fraction; when given, only timesteps below
            the corresponding training-timestep cutoff are kept and
            ``strength`` is ignored.

    Returns:
        tuple: ``(timesteps, number_of_steps)``.
    """
    scheduler = self.scheduler

    if denoising_start is None:
        # get the original timestep using init_timestep
        steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
        first_index = max(num_inference_steps - steps_to_run, 0)
        selected = scheduler.timesteps[first_index * scheduler.order :]
        return selected, num_inference_steps - first_index

    # Strength is irrelevant if we directly request a timestep to start at;
    # that is, strength is determined by the denoising_start instead.
    timesteps = scheduler.timesteps[0:]
    train_steps = scheduler.config.num_train_timesteps
    cutoff = int(round(train_steps - (denoising_start * train_steps)))

    remaining = (timesteps < cutoff).sum().item()
    if scheduler.order == 2 and remaining % 2 == 0:
        # if the scheduler is a 2nd order scheduler we might have to do +1
        # because `num_inference_steps` might be even given that every timestep
        # (except the highest one) is duplicated. If `num_inference_steps` is even it would
        # mean that we cut the timesteps in the middle of the denoising step
        # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
        # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
        remaining = remaining + 1

    # because t_n+1 >= t_n, we slice the timesteps starting from the end
    return timesteps[-remaining:], remaining
By adding 1 # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler num_inference_steps = num_inference_steps + 1 # because t_n+1 >= t_n, we slice the timesteps starting from the end timesteps = timesteps[-num_inference_steps:] return timesteps, num_inference_steps return timesteps, num_inference_steps - t_start # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype, text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) add_neg_time_ids = list( negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) ) else: add_time_ids = list(original_size + crops_coords_top_left + target_size) add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) passed_add_embed_dim = ( self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features if ( expected_add_embed_dim > passed_add_embed_dim and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim ): raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
    """Move ``self.vae`` to float32, keeping some sub-modules in the previous dtype when possible.

    If the first mid-block attention of the VAE decoder uses ``AttnProcessor2_0``
    or ``XFormersAttnProcessor``, ``post_quant_conv``, ``decoder.conv_in`` and
    ``decoder.mid_block`` are moved back to the dtype the VAE had on entry.
    """
    vae = self.vae
    previous_dtype = vae.dtype
    vae.to(dtype=torch.float32)

    attn_processor = vae.decoder.mid_block.attentions[0].processor
    memory_efficient_attention = isinstance(attn_processor, (AttnProcessor2_0, XFormersAttnProcessor))
    if not memory_efficient_attention:
        return

    # if xformers or torch_2_0 is used attention block does not need
    # to be in float32 which can save lots of memory
    vae.post_quant_conv.to(previous_dtype)
    vae.decoder.conv_in.to(previous_dtype)
    vae.decoder.mid_block.to(previous_dtype)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Sinusoidal embedding of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scales; the values are multiplied by 1000 before embedding.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate. Odd values get one trailing zero column.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, laid out as
        all sines followed by all cosines.
    """
    assert len(w.shape) == 1
    scaled_w = w * 1000.0

    num_frequencies = embedding_dim // 2
    log_spacing = torch.log(torch.tensor(10000.0)) / (num_frequencies - 1)
    frequencies = torch.exp(torch.arange(num_frequencies, dtype=dtype) * -log_spacing)

    # Outer product: one row per guidance value, one column per frequency.
    angles = scaled_w.to(dtype).unsqueeze(1) * frequencies.unsqueeze(0)
    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    if embedding_dim % 2 == 1:  # zero pad
        embedding = torch.nn.functional.pad(embedding, (0, 1))

    assert embedding.shape == (w.shape[0], embedding_dim)
    return embedding
@property
def guidance_scale(self):
    """Value stored in ``self._guidance_scale``."""
    return self._guidance_scale

@property
def guidance_rescale(self):
    """Value stored in ``self._guidance_rescale``."""
    return self._guidance_rescale

@property
def clip_skip(self):
    """Value stored in ``self._clip_skip``."""
    return self._clip_skip

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """Whether guidance is active: scale above 1 and ``unet.config.time_cond_proj_dim`` is None."""
    scale_enables_guidance = self._guidance_scale > 1
    if not scale_enables_guidance:
        # The UNet config is not consulted when the scale already disables guidance.
        return scale_enables_guidance
    return self.unet.config.time_cond_proj_dim is None

@property
def cross_attention_kwargs(self):
    """Value stored in ``self._cross_attention_kwargs``."""
    return self._cross_attention_kwargs

@property
def denoising_end(self):
    """Value stored in ``self._denoising_end``."""
    return self._denoising_end

@property
def denoising_start(self):
    """Value stored in ``self._denoising_start``."""
    return self._denoising_start

@property
def num_timesteps(self):
    """Value stored in ``self._num_timesteps``."""
    return self._num_timesteps

@property
def interrupt(self):
    """Value stored in ``self._interrupt``."""
    return self._interrupt
Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, aesthetic_score: float = 6.0, negative_aesthetic_score: float = 2.5, clip_skip: Optional[int] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders image (`PIL.Image.Image`): `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will be masked out with `mask_image` and repainted according to `prompt`. mask_image (`PIL.Image.Image`): `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. 
This is set to 1024 by default for the best results. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. padding_mask_crop (`int`, *optional*, defaults to `None`): The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large and contain information irrelevant for inpainting, such as background. strength (`float`, *optional*, defaults to 0.9999): Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked portion of the reference `image`. Note that in the case of `denoising_start` being declared as an integer, the value of `strength` will be ignored. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. 
If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. denoising_start (`float`, *optional*): When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. 
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a specific image resolution. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a target image resolution. It should be as same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. aesthetic_score (`float`, *optional*, defaults to 6.0): Used to simulate an aesthetic score of the generated image by influencing the positive text condition. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_aesthetic_score (`float`, *optional*, defaults to 2.5): Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to simulate an aesthetic score of the generated image by influencing the negative text condition. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 
callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. `tuple. When returning a tuple, the first element is a list with the generated images. """ callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs self.check_inputs( prompt, prompt_2, image, mask_image, height, width, strength, callback_steps, output_type, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, padding_mask_crop, ) self._guidance_scale = guidance_scale self._guidance_rescale = guidance_rescale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end self._denoising_start = denoising_start self._interrupt = False # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) # 4. set timesteps def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, device, denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, ) # check that number of inference steps is not < 1 - as this doesn't make sense if num_inference_steps < 1: raise ValueError( f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." ) # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise is_strength_max = strength == 1.0 # 5. 
Preprocess mask and image if padding_mask_crop is not None: crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) resize_mode = "fill" else: crops_coords = None resize_mode = "default" original_image = image init_image = self.image_processor.preprocess( image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode ) init_image = init_image.to(dtype=torch.float32) mask = self.mask_processor.preprocess( mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords ) if masked_image_latents is not None: masked_image = masked_image_latents elif init_image.shape[1] == 4: # if images are in latent space, we can't mask it masked_image = None else: masked_image = init_image * (mask < 0.5) # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels num_channels_unet = self.unet.config.in_channels return_image_latents = num_channels_unet == 4 add_noise = True if self.denoising_start is None else False latents_outputs = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, image=init_image, timestep=latent_timestep, is_strength_max=is_strength_max, add_noise=add_noise, return_noise=True, return_image_latents=return_image_latents, ) if return_image_latents: latents, noise, image_latents = latents_outputs else: latents, noise = latents_outputs # 7. Prepare mask latent variables mask, masked_image_latents = self.prepare_mask_latents( mask, masked_image, batch_size * num_images_per_prompt, height, width, prompt_embeds.dtype, device, generator, self.do_classifier_free_guidance, ) # 8. 
Check that sizes of mask, masked image and latents match if num_channels_unet == 9: # default case for runwayml/stable-diffusion-inpainting num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) elif num_channels_unet != 4: raise ValueError( f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." ) # 8.1 Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline height, width = latents.shape[-2:] height = height * self.vae_scale_factor width = width * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 10. 
Prepare added time ids & embeddings if negative_original_size is None: negative_original_size = original_size if negative_target_size is None: negative_target_size = target_size add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 11. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) if ( self.denoising_end is not None and self.denoising_start is not None and denoising_value_valid(self.denoising_end) and denoising_value_valid(self.denoising_start) and self.denoising_start >= self.denoising_end ): raise ValueError( f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + f" {self.denoising_end} when using type float." 
) elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] # 11.1 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) if num_channels_unet == 9: latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - 
noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) if num_channels_unet == 4: init_latents_proper = image_latents if self.do_classifier_free_guidance: init_mask, _ = mask.chunk(2) else: init_mask = mask if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] init_latents_proper = self.scheduler.add_noise( init_latents_proper, noise, torch.tensor([noise_timestep]) ) latents = (1 - init_mask) * init_latents_proper + init_mask * latents if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) negative_pooled_prompt_embeds = callback_outputs.pop( "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) mask = callback_outputs.pop("mask", mask) masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > 
num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if XLA_AVAILABLE: xm.mark_step() if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) elif latents.dtype != self.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None if has_latents_mean and has_latents_std: latents_mean = ( torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents_std = ( torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean else: latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: return StableDiffusionXLPipelineOutput(images=latents) # apply watermark if available if self.watermark is not None: image = self.watermark.apply_watermark(image) image = self.image_processor.postprocess(image, output_type=output_type) if padding_mask_crop is not None: image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] # 
Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py ================================================ # Copyright 2024 Harutatsu Akiyama and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import PIL.Image import torch from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, is_invisible_watermark_available, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import StableDiffusionXLPipelineOutput if is_invisible_watermark_available(): from 
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale the guided prediction `noise_cfg` using the spread of the text-conditioned prediction, then blend the
    rescaled result with the original one. Based on Section 3.4 of [Common Diffusion Noise Schedules and Sample
    Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg: guided noise prediction. Its standard deviation is taken over every axis except the first.
        noise_pred_text: text-conditioned noise prediction, reduced over the same axes.
        guidance_rescale: blend weight given to the rescaled prediction; the original prediction receives
            `1 - guidance_rescale`.

    Returns:
        `guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg`.
    """

    def _std_over_non_leading_axes(tensor):
        # keepdim=True keeps the result broadcastable against the full tensor.
        reduce_axes = list(range(1, tensor.ndim))
        return tensor.std(dim=reduce_axes, keepdim=True)

    text_std = _std_over_non_leading_axes(noise_pred_text)
    cfg_std = _std_over_non_leading_axes(noise_cfg)

    # Bring the guided prediction's spread in line with the text prediction (fixes overexposure).
    rescaled_prediction = noise_cfg * (text_std / cfg_std)

    # Keep part of the un-rescaled guidance result to avoid "plain looking" images.
    return guidance_rescale * rescaled_prediction + (1 - guidance_rescale) * noise_cfg
Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`CLIPTokenizer`): Second Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config of `stabilityai/stable-diffusion-xl-refiner-1-0`. force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of `stabilityai/stable-diffusion-xl-base-1-0`. add_watermarker (`bool`, *optional*): Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to watermark output images. If not defined, it will default to True if the package is installed, otherwise no watermarker will be used. is_cosxl_edit (`bool`, *optional*): When set the image latents are scaled. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
    is_cosxl_edit: Optional[bool] = False,
):
    """
    Register the pipeline components and derive the helper attributes used by the rest of the pipeline.

    `add_watermarker=None` falls back to `is_invisible_watermark_available()`; when the resulting flag is falsy
    `self.watermark` is set to `None`.
    """
    super().__init__()

    # Same keyword set and order as a direct `register_modules(vae=..., ...)` call.
    components = dict(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        scheduler=scheduler,
    )
    self.register_modules(**components)
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

    # One factor of two per VAE block after the first (presumably the spatial downsampling factor).
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.default_sample_size = self.unet.config.sample_size
    self.is_cosxl_edit = is_cosxl_edit

    # Only probe for the watermark package when the caller did not decide explicitly.
    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()
    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. 
lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = 
tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder( text_input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt, negative_prompt_2] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) prompt_embeds_dtype = self.text_encoder_2.dtype if self.text_encoder_2 is not None else self.unet.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed * num_images_per_prompt, -1 ) if do_classifier_free_guidance: 
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments for ``self.scheduler.step``.

    Schedulers do not share a single ``step`` signature, so ``eta`` and
    ``generator`` are forwarded only when the active scheduler's ``step``
    declares a parameter with that name. ``eta`` is the η of the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should lie in [0, 1]; schedulers
    that do not accept it simply never receive it.

    Args:
        generator: Random generator(s) to pass on when ``step`` accepts one.
        eta: DDIM η value to pass on when ``step`` accepts it.

    Returns:
        dict: ``"eta"`` and/or ``"generator"`` (inserted in that order) for
        each name the scheduler accepts; empty when it accepts neither.
    """
    # Inspect the signature once and test both optional names against it.
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """Return the starting latents for denoising, scaled for the scheduler.

    Args:
        batch_size: Effective batch size (first dimension of the latents).
        num_channels_latents: Number of latent channels (second dimension).
        height: Pixel height; floor-divided by ``self.vae_scale_factor``.
        width: Pixel width; floor-divided by ``self.vae_scale_factor``.
        dtype: dtype used when fresh noise is sampled.
        device: Device the latents are created on or moved to.
        generator: Generator, or list of generators, used to sample noise.
        latents: Optional caller-supplied latents; when given they are only
            moved to ``device`` instead of being sampled.

    Returns:
        The latents multiplied by ``self.scheduler.init_noise_sigma``.

    Raises:
        ValueError: If ``generator`` is a list whose length differs from
            ``batch_size``.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(width) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    # A list of generators must provide exactly one generator per sample.
    if isinstance(generator, list):
        if len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

    initial = (
        randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        if latents is None
        else latents.to(device)
    )

    # Scale the initial noise by the standard deviation required by the scheduler.
    return initial * self.scheduler.init_noise_sigma
) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // image_latents.shape[0] image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." ) else: image_latents = torch.cat([image_latents], dim=0) if do_classifier_free_guidance: uncond_image_latents = torch.zeros_like(image_latents) image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) if image_latents.dtype != self.vae.dtype: image_latents = image_latents.to(dtype=self.vae.dtype) if self.is_cosxl_edit: image_latents = image_latents * self.vae.config.scaling_factor return image_latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features if expected_add_embed_dim != passed_add_embed_dim: raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
) add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae def upcast_vae(self): dtype = self.vae.dtype self.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( self.vae.decoder.mid_block.attentions[0].processor, ( AttnProcessor2_0, XFormersAttnProcessor, FusedAttnProcessor2_0, ), ) # if xformers or torch_2_0 is used attention block does not need # to be in float32 which can save lots of memory if use_torch_2_0_or_xformers: self.vae.post_quant_conv.to(dtype) self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 100, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, image_guidance_scale: float = 1.5, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, ): r""" Function invoked when 
calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders. image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`): The image(s) to modify with the pipeline. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. image_guidance_scale (`float`, *optional*, defaults to 1.5): Image guidance scale is to push the generated image towards the initial image `image`. Image guidance scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to generate images that are closely linked to the source image `image`, usually at the expense of lower image quality. This pipeline requires a value of at least `1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings.
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). aesthetic_score (`float`, *optional*, defaults to 6.0): Used to simulate an aesthetic score of the generated image by influencing the positive text condition. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
negative_aesthetic_score (`float`, *optional*, defaults to 2.5): Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to simulate an aesthetic score of the generated image by influencing the negative text condition. Examples: Returns: [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) if image is None: raise ValueError("`image` input cannot be undefined.") # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0 # 3. 
Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Preprocess image image = self.image_processor.preprocess(image, height=height, width=width).to(device) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 6. Prepare Image latents image_latents = self.prepare_image_latents( image, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, do_classifier_free_guidance, ) # 7. Prepare latent variables num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 8. Check that shapes of latents and image match the UNet channels num_channels_image = image_latents.shape[1] if num_channels_latents + num_channels_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " f" = {num_channels_latents + num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) # 9. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 10. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if do_classifier_free_guidance: # The extra concat similar to how it's done in SD InstructPix2Pix. prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], dim=0) add_text_embeds = torch.cat( [add_text_embeds, negative_pooled_prompt_embeds, negative_pooled_prompt_embeds], dim=0 ) add_time_ids = torch.cat([add_time_ids, add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) # 11. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Expand the latents if we are doing classifier free guidance. # The latents are expanded 3 times because for pix2pix the guidance # is applied for both the text and the input image. 
latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents # concat latents, image_latents in the channel dimension scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} noise_pred = self.unet( scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) noise_pred = ( noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_image) + image_guidance_scale * (noise_pred_image - noise_pred_uncond) ) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if XLA_AVAILABLE: xm.mark_step() if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) elif latents.dtype != self.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None if has_latents_mean and has_latents_std: latents_mean = ( torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents_std = ( torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean else: latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: return StableDiffusionXLPipelineOutput(images=latents) # apply watermark if available if self.watermark is not None: image = self.watermark.apply_watermark(image) image = 
class StableDiffusionXLWatermarker:
    """Embed the fixed SDXL invisible watermark into batches of images.

    The encoder is configured once, in ``__init__``, with the module-level
    ``WATERMARK_BITS`` payload and reused for every call.
    """

    # Size threshold (in pixels) below which watermarking is skipped.
    _MIN_SIDE = 256

    def __init__(self):
        self.watermark = WATERMARK_BITS
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def apply_watermark(self, images: torch.Tensor):
        """Return ``images`` with the invisible watermark embedded.

        Args:
            images: Batch that is rescaled from [-1, 1] to [0, 255] and
                permuted from channels-first to channels-last before encoding.

        Returns:
            The input tensor itself, untouched, when the frames are too small
            to encode. Otherwise a new float32 tensor built from the encoder's
            NumPy output (so it lives on the CPU), back in channels-first
            layout and clamped to [-1, 1].
        """
        # Can't encode images that are smaller than 256 pixels wide (the
        # original guard, kept exactly as it was).
        if images.shape[-1] < self._MIN_SIDE:
            return images
        # The width-only guard let wide-but-very-short frames through to the
        # encoder. Skip those as well instead of failing mid-pipeline.
        # NOTE(review): the area threshold is meant to mirror imwatermark's
        # own minimum-size check; confirm against the installed version.
        if images.shape[-2] * images.shape[-1] < self._MIN_SIDE**2:
            return images

        # [-1, 1] float tensor -> [0, 255] float32 array in (N, H, W, C) layout.
        frames = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy()

        # Convert RGB to BGR, which is the channel order expected by the watermark encoder.
        frames = frames[:, :, :, ::-1]

        # Add watermark and convert BGR back to RGB.
        encoded = np.array([self.encoder.encode(frame, "dwtDct")[:, :, ::-1] for frame in frames])

        # Back to channels-first and to the [-1, 1] range.
        restored = torch.from_numpy(encoded).permute(0, 3, 1, 2)
        return torch.clamp(2 * (restored / 255 - 0.5), min=-1.0, max=1.0)
All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Union import numpy as np import PIL.Image import torch from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel from ...schedulers import EulerDiscreteScheduler from ...utils import BaseOutput, logging, replace_example_docstring from ...utils.torch_utils import is_compiled_module, randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import StableVideoDiffusionPipeline >>> from diffusers.utils import load_image, export_to_video >>> pipe = StableVideoDiffusionPipeline.from_pretrained( ... "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16" ... ) >>> pipe.to("cuda") >>> image = load_image( ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg" ... 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure the scheduler through `set_timesteps` and read back the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and to read timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Passed to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps overriding the scheduler's spacing strategy. Cannot be combined with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas overriding the scheduler's spacing strategy. Cannot be combined with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's timestep schedule and the number of inference steps. For custom
        schedules the count is the length of the resulting schedule; otherwise it is `num_inference_steps` as given.

    Raises:
        `ValueError`: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` has no
        parameter for the custom schedule that was requested.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Custom timestep schedule.
    if timesteps is not None:
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        schedule = scheduler.timesteps
        return schedule, len(schedule)

    # Custom sigma schedule.
    if sigmas is not None:
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        schedule = scheduler.timesteps
        return schedule, len(schedule)

    # Default: let the scheduler space `num_inference_steps` steps itself.
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    return scheduler.timesteps, num_inference_steps
def _encode_image(
    self,
    image: PipelineImageInput,
    device: Union[str, torch.device],
    num_videos_per_prompt: int,
    do_classifier_free_guidance: bool,
) -> torch.Tensor:
    """
    Embed the conditioning image with the CLIP image encoder.

    Non-tensor inputs are converted to a tensor and resized to 224x224 with antialiasing; tensor inputs are
    handed to the feature extractor as they are. The embeddings are repeated `num_videos_per_prompt` times per
    input and, when `do_classifier_free_guidance` is set, prefixed with an all-zero copy that serves as the
    unconditional branch.

    Returns:
        `torch.Tensor`: Image embeddings with a sequence axis inserted at dim 1. The batch axis is
        `batch * num_videos_per_prompt`, doubled under classifier-free guidance.
    """
    encoder_dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        pixels = self.video_processor.numpy_to_pt(self.video_processor.pil_to_numpy(image))

        # The resize is done in [-1, 1] to match the original implementation, then mapped back to [0, 1].
        pixels = _resize_with_antialiasing(pixels * 2.0 - 1.0, (224, 224))
        image = (pixels + 1.0) / 2.0

    # Only CLIP normalization is applied here; resizing, cropping and rescaling are all switched off.
    clip_inputs = self.feature_extractor(
        images=image,
        do_normalize=True,
        do_center_crop=False,
        do_resize=False,
        do_rescale=False,
        return_tensors="pt",
    ).pixel_values
    clip_inputs = clip_inputs.to(device=device, dtype=encoder_dtype)

    embeds = self.image_encoder(clip_inputs).image_embeds.unsqueeze(1)

    # duplicate the embeddings for each generation per prompt, using an mps friendly method
    n_inputs, n_tokens, _ = embeds.shape
    embeds = embeds.repeat(1, num_videos_per_prompt, 1)
    embeds = embeds.view(n_inputs * num_videos_per_prompt, n_tokens, -1)

    if not do_classifier_free_guidance:
        return embeds

    # Stack the zero (unconditional) and real (conditional) image embeddings into one batch so a single
    # forward pass serves both guidance branches.
    return torch.cat([torch.zeros_like(embeds), embeds])
def _get_add_time_ids(
    self,
    fps: int,
    motion_bucket_id: int,
    noise_aug_strength: float,
    dtype: torch.dtype,
    batch_size: int,
    num_videos_per_prompt: int,
    do_classifier_free_guidance: bool,
):
    """
    Build the micro-conditioning tensor `[fps, motion_bucket_id, noise_aug_strength]` for the UNet.

    Args:
        fps (`int`): Frame-rate conditioning value (passed through unchanged).
        motion_bucket_id (`int`): Motion conditioning value.
        noise_aug_strength (`float`): Noise augmentation conditioning value.
        dtype (`torch.dtype`): dtype of the returned tensor.
        batch_size (`int`): Number of input images.
        num_videos_per_prompt (`int`): Number of videos generated per input image.
        do_classifier_free_guidance (`bool`): If `True`, the rows are duplicated for the two guidance branches.

    Returns:
        `torch.Tensor`: Tensor of shape `(batch_size * num_videos_per_prompt, 3)`, or twice as many rows when
        classifier-free guidance is enabled.

    Raises:
        `ValueError`: If `unet.config.addition_time_embed_dim * 3` does not match the input width of
        `unet.add_embedding.linear_1`.
    """
    add_time_ids = [fps, motion_bucket_id, noise_aug_strength]

    passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
    expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

    if expected_add_embed_dim != passed_add_embed_dim:
        # The message names the two values that are actually compared above.
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.addition_time_embed_dim` and the input size of `unet.add_embedding.linear_1`."
        )

    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)

    if do_classifier_free_guidance:
        add_time_ids = torch.cat([add_time_ids, add_time_ids])

    return add_time_ids
def check_inputs(self, image, height, width):
    """
    Validate the user-facing arguments of the pipeline call.

    Args:
        image: The conditioning input; must be a `torch.Tensor`, a `PIL.Image.Image` or a `list`.
        height (`int`): Requested output height in pixels.
        width (`int`): Requested output width in pixels.

    Raises:
        `ValueError`: If `image` has an unsupported type, or if `height` or `width` is not a multiple of 8.
    """
    # Checked one type at a time so evaluation stops at the first match.
    type_is_supported = (
        isinstance(image, torch.Tensor) or isinstance(image, PIL.Image.Image) or isinstance(image, list)
    )
    if not type_is_supported:
        raise ValueError(
            "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
            f" {type(image)}"
        )

    if any(side % 8 != 0 for side in (height, width)):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
    height: int = 576,
    width: int = 1024,
    num_frames: Optional[int] = None,
    num_inference_steps: int = 25,
    sigmas: Optional[List[float]] = None,
    min_guidance_scale: float = 1.0,
    max_guidance_scale: float = 3.0,
    fps: int = 7,
    motion_bucket_id: int = 127,
    noise_aug_strength: float = 0.02,
    decode_chunk_size: Optional[int] = None,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    return_dict: bool = True,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
            Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
            1]`.
        height (`int`, *optional*, defaults to 576):
            The height in pixels of the generated video frames. Must be divisible by 8. If a falsy value (such
            as `None`) is passed, `self.unet.config.sample_size * self.vae_scale_factor` is used instead.
        width (`int`, *optional*, defaults to 1024):
            The width in pixels of the generated video frames. Must be divisible by 8. If a falsy value (such as
            `None`) is passed, `self.unet.config.sample_size * self.vae_scale_factor` is used instead.
        num_frames (`int`, *optional*):
            The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
            `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
        num_inference_steps (`int`, *optional*, defaults to 25):
            The number of denoising steps. More denoising steps usually lead to a higher quality video at the
            expense of slower inference.
        sigmas (`List[float]`, *optional*):
            Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
            their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used.
        min_guidance_scale (`float`, *optional*, defaults to 1.0):
            The minimum guidance scale. Used for the classifier free guidance with first frame.
        max_guidance_scale (`float`, *optional*, defaults to 3.0):
            The maximum guidance scale. Used for the classifier free guidance with last frame. Classifier free
            guidance is enabled for the whole call when this value is greater than 1.
        fps (`int`, *optional*, defaults to 7):
            Frames per second. The rate at which the generated images shall be exported to a video after
            generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training,
            so `fps - 1` is what is passed to the model.
        motion_bucket_id (`int`, *optional*, defaults to 127):
            Used for conditioning the amount of motion for the generation. The higher the number the more motion
            will be in the video.
        noise_aug_strength (`float`, *optional*, defaults to 0.02):
            The amount of noise added to the init image, the higher it is the less the video will look like the
            init image. Increase it for more motion.
        decode_chunk_size (`int`, *optional*):
            The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at
            the expense of more memory usage. By default (`None`), `num_frames` is used, i.e. the decoder
            decodes all frames of a video at once. For lower memory usage, reduce `decode_chunk_size`.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. If not provided, a latents tensor is generated by sampling using the supplied random
            `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `pil`, `np` or `pt`. With `"latent"` the
            VAE decoding step is skipped and the final latents are returned.
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference. The function is called
            with the following arguments:
                `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
            `callback_kwargs` will include a list of all tensors as specified by
            `callback_on_step_end_tensor_inputs`. If the returned mapping contains a `"latents"` entry, it
            replaces the current latents.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
            the `._callback_tensor_inputs` attribute of your pipeline class.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a `StableVideoDiffusionPipelineOutput` instead of the bare frames.

    Examples:

    Returns:
        `StableVideoDiffusionPipelineOutput` or frames:
            If `return_dict` is `True`, a `StableVideoDiffusionPipelineOutput` is returned. Otherwise the frames
            object itself (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.Tensor`) is returned directly;
            note that it is not wrapped in a tuple.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
    decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(image, height, width)

    # 2. Define call parameters
    if isinstance(image, PIL.Image.Image):
        batch_size = 1
    elif isinstance(image, list):
        batch_size = len(image)
    else:
        batch_size = image.shape[0]
    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    # The scalar is stored first so that `self.do_classifier_free_guidance` can be queried while encoding;
    # it is replaced by the per-frame tensor in step 8.
    self._guidance_scale = max_guidance_scale

    # 3. Encode input image
    image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)

    # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
    # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
    fps = fps - 1

    # 4. Encode input image using VAE
    image = self.video_processor.preprocess(image, height=height, width=width).to(device)
    # Noise augmentation of the conditioning image, scaled by `noise_aug_strength`.
    noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
    image = image + noise_aug_strength * noise

    # Run the VAE encoder in float32 when the config asks for it, then restore float16 below.
    needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
    if needs_upcasting:
        self.vae.to(dtype=torch.float32)

    image_latents = self._encode_vae_image(
        image,
        device=device,
        num_videos_per_prompt=num_videos_per_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
    )
    image_latents = image_latents.to(image_embeddings.dtype)

    # cast back to fp16 if needed
    if needs_upcasting:
        self.vae.to(dtype=torch.float16)

    # Repeat the image latents for each frame so we can concatenate them with the noise
    # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
    image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)

    # 5. Get Added Time IDs
    added_time_ids = self._get_add_time_ids(
        fps,
        motion_bucket_id,
        noise_aug_strength,
        image_embeddings.dtype,
        batch_size,
        num_videos_per_prompt,
        self.do_classifier_free_guidance,
    )
    added_time_ids = added_time_ids.to(device)

    # 6. Prepare timesteps
    # The fourth positional argument is `timesteps`; only custom `sigmas` are supported by this pipeline.
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)

    # 7. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_frames,
        num_channels_latents,
        height,
        width,
        image_embeddings.dtype,
        device,
        generator,
        latents,
    )

    # 8. Prepare guidance scale
    # One guidance value per frame, increasing linearly from `min_guidance_scale` to `max_guidance_scale`,
    # with trailing singleton dims appended so it broadcasts against the latents.
    guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
    guidance_scale = guidance_scale.to(device, latents.dtype)
    guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
    guidance_scale = _append_dims(guidance_scale, latents.ndim)

    self._guidance_scale = guidance_scale

    # 9. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    self._num_timesteps = len(timesteps)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Concatenate image_latents over channels dimension
            latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=image_embeddings,
                added_time_ids=added_time_ids,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

            if callback_on_step_end is not None:
                callback_kwargs = {}
                # Requested tensors are looked up by name among this function's local variables.
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    if not output_type == "latent":
        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
        frames = self.decode_latents(latents, num_frames, decode_chunk_size)
        frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
    else:
        frames = latents

    self.maybe_free_model_hooks()

    # NOTE(review): the bare `frames` object is returned here, not a 1-tuple — confirm callers expect that.
    if not return_dict:
        return frames

    return StableVideoDiffusionPipelineOutput(frames=frames)
def _compute_padding(kernel_size):
    """
    Compute the flat padding list that keeps the spatial size unchanged for a given kernel.

    The result follows the layout expected by `torch.nn.functional.pad`
    (https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad): pairs of `(front, rear)` amounts,
    starting with the LAST entry of `kernel_size` and moving backwards, e.g.
    `(padding_left, padding_right, padding_top, padding_bottom)` for a `[height, width]` kernel.

    Args:
        kernel_size: Sequence of at least two kernel extents.

    Returns:
        `List[int]`: `2 * len(kernel_size)` padding amounts.

    Raises:
        `AssertionError`: If fewer than two kernel extents are given.
    """
    if len(kernel_size) < 2:
        raise AssertionError(kernel_size)

    padding = []
    for total in reversed([extent - 1 for extent in kernel_size]):
        # An even kernel has an odd total, so the extra cell goes to the rear (asymmetric padding).
        front = total // 2
        padding.extend((front, total - front))
    return padding
def _gaussian(window_size: int, sigma):
    """
    Build normalized 1D Gaussian kernels.

    Args:
        window_size (`int`): Number of taps in each kernel.
        sigma: Standard deviation. Either a Python number (`int` or `float`), which yields a single kernel, or
            a tensor whose first dimension is the batch size and which broadcasts against `(batch, window_size)`
            (callers in this file pass shape `(batch, 1)`).

    Returns:
        `torch.Tensor`: Kernels of shape `(batch, window_size)`; each row sums to 1.
    """
    # Accept any plain Python number; previously an `int` sigma fell through and failed on `.shape`.
    if isinstance(sigma, (int, float)):
        sigma = torch.tensor([[float(sigma)]])

    batch_size = sigma.shape[0]

    # Sample positions centered on the window.
    x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)

    # Even windows have no center tap: shift by half a step so the samples are symmetric around zero.
    if window_size % 2 == 0:
        x = x + 0.5

    gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))

    return gauss / gauss.sum(-1, keepdim=True)
OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_stable_diffusion_adapter import StableDiffusionAdapterPipeline from .pipeline_stable_diffusion_xl_adapter import StableDiffusionXLAdapterPipeline else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py ================================================ # Copyright 2024 TencentARC and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
@dataclass
class StableDiffusionAdapterPipelineOutput(BaseOutput):
    """
    Output class for the Stable Diffusion T2I-Adapter pipeline.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
        nsfw_content_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, or `None` if safety checking could not be performed.
    """

    # Generated images, either as PIL images or as one stacked numpy array.
    images: Union[List[PIL.Image.Image], np.ndarray]
    # Per-image safety flags; `None` when no safety check was run.
    nsfw_content_detected: Optional[List[bool]]
def _preprocess_adapter_image(image, height, width):
    """
    Convert an adapter conditioning input into a batched tensor.

    Args:
        image: One of
            - a `torch.Tensor`, returned unchanged;
            - a `PIL.Image.Image` or a list of them, resized to `(width, height)` with Lanczos resampling,
              scaled to `[0, 1]` and returned as a float32 tensor in `(batch, channels, height, width)` layout
              (2D grayscale arrays get a trailing channel axis);
            - a list of tensors, stacked along a new batch axis when 3D or concatenated along dim 0 when 4D.
        height (`int`): Target height used when resizing PIL images.
        width (`int`): Target width used when resizing PIL images.

    Returns:
        `torch.Tensor` for the supported inputs above. A sequence whose first element is neither a PIL image
        nor a tensor is returned unchanged.

    Raises:
        `ValueError`: If a list of tensors is given whose first element is not 3D or 4D.
    """
    if isinstance(image, torch.Tensor):
        return image

    # Only a bare (non-sequence) input can be a single PIL image that needs wrapping.
    if not isinstance(image, (list, tuple)) and isinstance(image, PIL.Image.Image):
        image = [image]

    first = image[0]
    if isinstance(first, torch.Tensor):
        if first.ndim == 3:
            image = torch.stack(image, dim=0)
        elif first.ndim == 4:
            image = torch.cat(image, dim=0)
        else:
            raise ValueError(
                f"Invalid image tensor! Expecting image tensor with 3 or 4 dimensions, but received: {first.ndim}"
            )
    elif isinstance(first, PIL.Image.Image):
        image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image]
        image = [
            i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image
        ]  # expand [h, w] or [h, w, c] to [b, h, w, c]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]
        image = torch.from_numpy(image)
    return image
num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. sigmas (`List[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." 
) scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin): r""" Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter https://arxiv.org/abs/2302.08453 This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`): Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a list, the outputs from each Adapter are added together to create one combined additional conditioning. adapter_weights (`List[float]`, *optional*, defaults to None): List of floats representing the weight which will be multiply to each adapter's output before adding them together. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and derive the VAE scale factor.

    Args:
        vae: Autoencoder used to decode latents; its `config.block_out_channels` determines
            `self.vae_scale_factor`.
        text_encoder: Text encoder registered as `self.text_encoder`.
        tokenizer: Tokenizer registered as `self.tokenizer`.
        unet: Denoising UNet registered as `self.unet`.
        adapter: A single adapter, a `MultiAdapter`, or a list/tuple of adapters. A list or tuple is wrapped
            into a `MultiAdapter` before registration.
        scheduler: Scheduler registered as `self.scheduler`.
        safety_checker: Optional safety checker. Passing `None` while `requires_safety_checker` is `True`
            only logs a warning.
        feature_extractor: Required whenever `safety_checker` is not `None`.
        requires_safety_checker: Stored in the pipeline config via `register_to_config`.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string, otherwise `{self.__class__}` is printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # A plain sequence of adapters is combined into a single MultiAdapter module.
    if isinstance(adapter, (list, tuple)):
        adapter = MultiAdapter(adapter)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        adapter=adapter,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # Each VAE block after the first halves the resolution, hence 2 ** (num_blocks - 1).
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`. `prompt_embeds` is repeated `num_images_per_prompt`
        times per prompt. `negative_prompt_embeds` is only generated and repeated when
        `do_classifier_free_guidance` is true; otherwise the value that was passed in (possibly `None`) is
        returned unchanged.

    Raises:
        TypeError: If `negative_prompt` and `prompt` are both given but have different types.
        ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # The batch size comes from `prompt` when it is a str/list, otherwise from `prompt_embeds`.
    # NOTE(review): if both are `None` the `else` branch fails on `None.shape`; callers are
    # presumably expected to validate first (e.g. via `check_inputs`) - confirm.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize a second time without truncation, only to detect and report text that was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # Only pass an attention mask when the text encoder config explicitly asks for one.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Choose the target dtype: text encoder first, then the UNet, then the embeddings' own dtype.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            # No negative prompt: use one empty string per prompt in the batch.
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the unconditional tokens to the sequence length of the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            # NOTE(review): this runs even when `lora_scale` is `None`, i.e. when no scaling was applied
            # above - confirm `unscale_lora_layers` treats `None` as a no-op.
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    """
    Decode latents with the VAE into a NumPy image batch (deprecated).

    Emits a deprecation notice, undoes the VAE scaling factor, decodes the latents, maps the result with
    `x / 2 + 0.5` and clamps it to `[0, 1]`.

    Args:
        latents: Latents to decode with `self.vae`.

    Returns:
        The decoded batch as a float32 NumPy array on the CPU, with the axis at position 1 moved last.
    """
    deprecate(
        "decode_latents",
        "1.0.0",
        "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead",
        standard_warn=False,
    )

    inverse_scale = 1 / self.vae.config.scaling_factor
    decoded = self.vae.decode(inverse_scale * latents, return_dict=False)[0]
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    # Casting to float32 is cheap and keeps the NumPy conversion compatible with bfloat16 inputs.
    return pixels.cpu().permute(0, 2, 3, 1).float().numpy()
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    image,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """
    Validate the arguments of a pipeline call, raising `ValueError` on the first problem found.

    Checks, in order: `height`/`width` divisible by 8; `callback_steps` a positive `int`; exactly one of
    `prompt` / `prompt_embeds` given, with `prompt` a `str` or `list`; `negative_prompt` and
    `negative_prompt_embeds` not both given; matching shapes when both embedding tensors are passed; and,
    when `self.adapter` is a `MultiAdapter`, `image` being a list with one entry per adapter.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None`, non-int values and non-positive ints are all rejected here.
    steps_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    # The remaining checks only apply when several adapters are combined.
    if not isinstance(self.adapter, MultiAdapter):
        return

    if not isinstance(image, list):
        raise ValueError(
            "MultiAdapter is enabled, but `image` is not a list. Please pass a list of images to `image`."
        )

    if len(image) != len(self.adapter.adapters):
        raise ValueError(
            f"MultiAdapter requires passing the same number of images as adapters. Given {len(image)} images and {len(self.adapter.adapters)} adapters."
        )
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Build sinusoidal embeddings for a 1-D tensor of guidance scales.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scales; each entry is multiplied by 1000 before being embedded.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate. Odd values are zero-padded by one column.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, laid out as all sines followed
        by all cosines.
    """
    assert len(w.shape) == 1
    scaled = w * 1000.0

    num_freqs = embedding_dim // 2
    # Frequencies decay geometrically from 1 down to 1 / 10000.
    log_step = torch.log(torch.tensor(10000.0)) / (num_freqs - 1)
    neg_step = -log_step
    freqs = torch.exp(torch.arange(num_freqs, dtype=dtype) * neg_step)

    angles = scaled.to(dtype).unsqueeze(1) * freqs.unsqueeze(0)
    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    needs_padding = embedding_dim % 2 == 1
    if needs_padding:
        # An odd dimension leaves one column unfilled; append a zero column on the right.
        embedding = torch.nn.functional.pad(embedding, (0, 1))

    assert embedding.shape == (w.shape[0], embedding_dim)
    return embedding
Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`): The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the type is specified as `torch.Tensor`, it is passed to Adapter as is. PIL.Image.Image` can also be accepted as an image. The control image is automatically resized to fit the output image. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the residual in the original unet. If multiple adapters are specified in init, you can set the corresponding scale as a list. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. 
Default height and width to unet height, width = self._default_height_width(height, width, image) device = self._execution_device # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, image, negative_prompt, prompt_embeds, negative_prompt_embeds ) self._guidance_scale = guidance_scale if isinstance(self.adapter, MultiAdapter): adapter_input = [] for one_image in image: one_image = _preprocess_adapter_image(one_image, height, width) one_image = one_image.to(device=device, dtype=self.adapter.dtype) adapter_input.append(one_image) else: adapter_input = _preprocess_adapter_image(image, height, width) adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clip_skip=clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.5 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) # 7. Denoising loop if isinstance(self.adapter, MultiAdapter): adapter_state = self.adapter(adapter_input, adapter_conditioning_scale) for k, v in enumerate(adapter_state): adapter_state[k] = v else: adapter_state = self.adapter(adapter_input) for k, v in enumerate(adapter_state): adapter_state[k] = v * adapter_conditioning_scale if num_images_per_prompt > 1: for k, v in enumerate(adapter_state): adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) if self.do_classifier_free_guidance: for k, v in enumerate(adapter_state): adapter_state[k] = torch.cat([v] * 2, dim=0) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=cross_attention_kwargs, down_intrablock_additional_residuals=[state.clone() for state in adapter_state], return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - 
noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if output_type == "latent": image = latents has_nsfw_concept = None elif output_type == "pil": # 8. Post-processing image = self.decode_latents(latents) # 9. Run safety checker image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # 10. Convert to PIL image = self.numpy_to_pil(image) else: # 8. Post-processing image = self.decode_latents(latents) # 9. Run safety checker image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionAdapterPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py ================================================ # Copyright 2024 TencentARC and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, XFormersAttnProcessor, ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import T2IAdapter, StableDiffusionXLAdapterPipeline, DDPMScheduler >>> from diffusers.utils import load_image >>> sketch_image = load_image("https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch.png").convert("L") >>> model_id = "stabilityai/stable-diffusion-xl-base-1.0" >>> adapter = T2IAdapter.from_pretrained( ... "Adapter/t2iadapter", ... subfolder="sketch_sdxl_1.0", ... torch_dtype=torch.float16, ... adapter_type="full_adapter_xl", ... ) >>> scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler") >>> pipe = StableDiffusionXLAdapterPipeline.from_pretrained( ... model_id, adapter=adapter, torch_dtype=torch.float16, variant="fp16", scheduler=scheduler ... 
).to("cuda") >>> generator = torch.manual_seed(42) >>> sketch_image_out = pipe( ... prompt="a photo of a dog in real world, high quality", ... negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality", ... image=sketch_image, ... generator=generator, ... guidance_scale=7.5, ... ).images[0] ``` """ def _preprocess_adapter_image(image, height, width): if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): image = [image] if isinstance(image[0], PIL.Image.Image): image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image] image = [ i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image ] # expand [h, w] or [h, w, c] to [b, h, w, c] image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image) elif isinstance(image[0], torch.Tensor): if image[0].ndim == 3: image = torch.stack(image, dim=0) elif image[0].ndim == 4: image = torch.cat(image, dim=0) else: raise ValueError( f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}" ) return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure `scheduler` via its `set_timesteps` method and hand back the schedule it produced. Supports a
    custom `timesteps` or a custom `sigmas` override. Extra keyword arguments are forwarded to
    `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler whose `set_timesteps` is called and whose `timesteps` attribute is read afterwards.
        num_inference_steps (`int`):
            Number of diffusion steps. Only forwarded to the scheduler when neither `timesteps` nor `sigmas`
            is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: `scheduler.timesteps` after configuration, and the number of inference
        steps. With a custom override the count is the length of the resulting schedule; otherwise it is the
        `num_inference_steps` that was passed in.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the override that was requested.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: the scheduler derives its own spacing from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Override path: make sure this scheduler's `set_timesteps` knows the keyword before using it.
    supported_params = inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`): Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a list, the outputs from each Adapter are added together to create one combined additional conditioning. adapter_weights (`List[float]`, *optional*, defaults to None): List of floats representing the weight which will be multiply to each adapter's output before adding them together. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
    scheduler: KarrasDiffusionSchedulers,
    force_zeros_for_empty_prompt: bool = True,
    feature_extractor: CLIPImageProcessor = None,
    image_encoder: CLIPVisionModelWithProjection = None,
):
    """
    Register every pipeline component and derive the helpers that depend on the VAE and UNet configs.

    All model/tokenizer/scheduler arguments are handed to `register_modules`; `force_zeros_for_empty_prompt`
    is stored through `register_to_config`. Afterwards the following attributes are set:

    - `vae_scale_factor`: `2 ** (number of VAE block_out_channels - 1)`
    - `image_processor`: a `VaeImageProcessor` built with that scale factor
    - `default_sample_size`: `unet.config.sample_size`
    """
    super().__init__()

    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "text_encoder_2": text_encoder_2,
        "tokenizer": tokenizer,
        "tokenizer_2": tokenizer_2,
        "unet": unet,
        "adapter": adapter,
        "scheduler": scheduler,
        "feature_extractor": feature_extractor,
        "image_encoder": image_encoder,
    }
    self.register_modules(**components)
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

    # Each VAE block after the first halves the spatial resolution.
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.default_sample_size = self.unet.config.sample_size
None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) 
text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not 
type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if 
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode an image with `self.image_encoder` and build the matching unconditional counterpart.

    Args:
        image: A `torch.Tensor`, or any other input, which is first run through `self.feature_extractor`
            to obtain `pixel_values`.
        device: Device the pixel values are moved to before encoding.
        num_images_per_prompt: How many times each encoded row is repeated (via `repeat_interleave` on dim 0).
        output_hidden_states: When truthy, return the penultimate hidden states instead of `image_embeds`.

    Returns:
        A `(conditional, unconditional)` tuple. With `output_hidden_states`, the unconditional part comes from
        encoding an all-zero image; otherwise it is an all-zero tensor shaped like the conditional embeddings.
    """
    # Match the encoder's parameter dtype so the forward pass does not hit a dtype mismatch.
    encoder_dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values
    pixel_values = image.to(device=device, dtype=encoder_dtype)

    if not output_hidden_states:
        cond_embeds = self.image_encoder(pixel_values).image_embeds
        cond_embeds = cond_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return cond_embeds, torch.zeros_like(cond_embeds)

    cond_hidden = self.image_encoder(pixel_values, output_hidden_states=True).hidden_states[-2]
    cond_hidden = cond_hidden.repeat_interleave(num_images_per_prompt, dim=0)

    blank_pixels = torch.zeros_like(pixel_values)
    uncond_hidden = self.image_encoder(blank_pixels, output_hidden_states=True).hidden_states[-2]
    uncond_hidden = uncond_hidden.repeat_interleave(num_images_per_prompt, dim=0)
    return cond_hidden, uncond_hidden
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only included when the
    scheduler's `step` method declares a parameter of that name.

    eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 and should be between [0, 1]

    Returns:
        `dict`: Contains `"eta"` and/or `"generator"` (in that order), or is empty.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    optional_args = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_args if name in step_params}
) elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) elif negative_prompt_2 is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Produce the initial latents for the denoising loop.

    Args:
        batch_size: Effective batch size; also the expected length of `generator` when it is a list.
        num_channels_latents: Channel count of the latent tensor.
        height, width: Pixel dimensions; each is integer-divided by `self.vae_scale_factor`.
        dtype: Dtype used when fresh noise is sampled.
        device: Device for freshly sampled noise, or the device user-supplied `latents` are moved to.
        generator: A generator or list of generators passed on to `randn_tensor`.
        latents: Optional pre-made latents. When given, they are only moved to `device`.

    Returns:
        The latents multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(width) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        initial = latents.to(device)
    else:
        initial = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma
def upcast_vae(self):
    """
    Cast `self.vae` to float32, keeping selected sub-modules in the original dtype when possible.

    If the first attention of the decoder's mid block uses `AttnProcessor2_0` or `XFormersAttnProcessor`,
    `post_quant_conv`, `decoder.conv_in` and `decoder.mid_block` are cast back to the dtype the VAE had on
    entry. With those processors the attention block does not need to be in float32, which can save a lot
    of memory.
    """
    original_dtype = self.vae.dtype
    self.vae.to(dtype=torch.float32)

    mid_block_processor = self.vae.decoder.mid_block.attentions[0].processor
    if not isinstance(mid_block_processor, (AttnProcessor2_0, XFormersAttnProcessor)):
        return

    modules_to_restore = (
        self.vae.post_quant_conv,
        self.vae.decoder.conv_in,
        self.vae.decoder.mid_block,
    )
    for module in modules_to_restore:
        module.to(original_dtype)
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """
    Build sinusoidal embeddings of the guidance scale.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scale values, one per embedding to generate. Values are multiplied by 1000
            before being embedded.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate. An odd dimension gets one trailing zero column.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`, laid out as all sine terms
        followed by all cosine terms.
    """
    assert len(w.shape) == 1
    scaled_w = w * 1000.0

    half_dim, leftover = divmod(embedding_dim, 2)
    # Frequencies decay geometrically from 1 down to 1/10000 across `half_dim` entries.
    log_spacing = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    frequencies = torch.exp(torch.arange(half_dim, dtype=dtype) * -log_spacing)

    angles = scaled_w.to(dtype)[:, None] * frequencies[None, :]
    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if leftover == 1:  # zero pad
        embedding = torch.nn.functional.pad(embedding, (0, 1))

    assert embedding.shape == (w.shape[0], embedding_dim)
    return embedding
torch.Tensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, original_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Optional[Tuple[int, int]] = None, negative_original_size: Optional[Tuple[int, int]] = None, negative_crops_coords_top_left: Tuple[int, int] = (0, 0), negative_target_size: Optional[Tuple[int, int]] = None, adapter_conditioning_scale: Union[float, List[float]] = 1.0, adapter_conditioning_factor: float = 1.0, clip_skip: Optional[int] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`): The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the type is specified as `torch.Tensor`, it is passed to Adapter as is. PIL.Image.Image` can also be accepted as an image. The control image is automatically resized to fit the output image. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. 
Anything below 512 pixels won't work well for [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and checkpoints that are not specifically fine-tuned on low resolutions. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionAdapterPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a specific image resolution. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): To negatively condition the generation process based on a target image resolution. It should be as same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the residual in the original unet. If multiple adapters are specified in init, you can set the corresponding scale as a list. adapter_conditioning_factor (`float`, *optional*, defaults to 1.0): The fraction of timesteps for which adapter should be applied. If `adapter_conditioning_factor` is `0.0`, adapter is not applied at all. 
If `adapter_conditioning_factor` is `1.0`, adapter is applied for all timesteps. If `adapter_conditioning_factor` is `0.5`, adapter is applied for half of the timesteps. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet height, width = self._default_height_width(height, width, image) device = self._execution_device if isinstance(self.adapter, MultiAdapter): adapter_input = [] for one_image in image: one_image = _preprocess_adapter_image(one_image, height, width) one_image = one_image.to(device=device, dtype=self.adapter.dtype) adapter_input.append(one_image) else: adapter_input = _preprocess_adapter_image(image, height, width) adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype) original_size = original_size or (height, width) target_size = target_size or (height, width) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, prompt_2, height, width, callback_steps, negative_prompt, negative_prompt_2, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, ) self._guidance_scale = guidance_scale # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3.1 Encode input prompt ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, clip_skip=clip_skip, ) # 3.2 Encode ip_adapter_image if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6.1 Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) # 7. Prepare added time ids & embeddings & adapter features if isinstance(self.adapter, MultiAdapter): adapter_state = self.adapter(adapter_input, adapter_conditioning_scale) for k, v in enumerate(adapter_state): adapter_state[k] = v else: adapter_state = self.adapter(adapter_input) for k, v in enumerate(adapter_state): adapter_state[k] = v * adapter_conditioning_scale if num_images_per_prompt > 1: for k, v in enumerate(adapter_state): adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) if self.do_classifier_free_guidance: for k, v in enumerate(adapter_state): adapter_state[k] = torch.cat([v] * 2, dim=0) add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( negative_original_size, negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = 
torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) # Apply denoising_end if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (denoising_end * self.scheduler.config.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds # predict the noise residual if i < int(num_inference_steps * adapter_conditioning_factor): down_intrablock_additional_residuals = [state.clone() for state in adapter_state] else: down_intrablock_additional_residuals = None noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=cross_attention_kwargs, down_intrablock_additional_residuals=down_intrablock_additional_residuals, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = 
noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) else: image = latents return StableDiffusionXLPipelineOutput(images=image) image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/text_to_video_synthesis/__init__.py ================================================ from typing import TYPE_CHECKING from ...utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_torch_available, is_transformers_available, ) _dummy_objects = {} _import_structure = {} try: if not (is_transformers_available() and is_torch_available()): raise 
# Placeholder objects that are attached to the module when torch/transformers are missing.
_dummy_objects = {}
# Maps each submodule name to the public names it provides; handed to `_LazyModule` below.
_import_structure = {}

try:
    # Every pipeline in this package needs both transformers and torch.
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    # Collect the dummy stand-ins so they can be set as module attributes at the end of this file.
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_output"] = ["TextToVideoSDPipelineOutput"]
    _import_structure["pipeline_text_to_video_synth"] = ["TextToVideoSDPipeline"]
    _import_structure["pipeline_text_to_video_synth_img2img"] = ["VideoToVideoSDPipeline"]
    _import_structure["pipeline_text_to_video_zero"] = ["TextToVideoZeroPipeline"]
    _import_structure["pipeline_text_to_video_zero_sdxl"] = ["TextToVideoZeroSDXLPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: import the real classes (or the dummy objects) directly, so the names exist for
    # type checkers and when slow imports are requested.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_output import TextToVideoSDPipelineOutput
        from .pipeline_text_to_video_synth import TextToVideoSDPipeline
        from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline
        from .pipeline_text_to_video_zero import TextToVideoZeroPipeline
        from .pipeline_text_to_video_zero_sdxl import TextToVideoZeroSDXLPipeline

else:
    import sys

    # Default path: replace this module in `sys.modules` with a `_LazyModule` built from
    # `_import_structure` (presumably deferring the submodule imports until first attribute access).
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    # Expose the dummy objects (if any were collected above) on the replacement module.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
@dataclass
class TextToVideoSDPipelineOutput(BaseOutput):
    """
    Output class for text-to-video pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of
            shape `(batch_size, num_frames, channels, height, width)`.
    """

    # Generated videos, in one of the three container formats described in the class docstring.
    frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
import TextToVideoSDPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import TextToVideoSDPipeline >>> from diffusers.utils import export_to_video >>> pipe = TextToVideoSDPipeline.from_pretrained( ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" ... ) >>> pipe.enable_model_cpu_offload() >>> prompt = "Spiderman is surfing" >>> video_frames = pipe(prompt).frames[0] >>> video_path = export_to_video(video_frames) >>> video_path ``` """ class TextToVideoSDPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin ): r""" Pipeline for text-to-video generation. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer (`CLIPTokenizer`): A [`~transformers.CLIPTokenizer`] to tokenize text. unet ([`UNet3DConditionModel`]): A [`UNet3DConditionModel`] to denoise the encoded video latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
""" model_cpu_offload_seq = "text_encoder->unet->vae" def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet3DConditionModel, scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) prompt_embeds_tuple = self.encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=lora_scale, **kwargs, ) # concatenate for backwards comp prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt def encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.

        Returns:
            `tuple`: `(prompt_embeds, negative_prompt_embeds)`. `prompt_embeds` is repeated `num_images_per_prompt`
            times per prompt. `negative_prompt_embeds` is only computed/repeated when `do_classifier_free_guidance`
            is truthy; otherwise it is returned exactly as it was passed in (`None` by default).
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        # The batch size comes from `prompt` when it is given, otherwise from the pre-computed embeddings.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Tokenize a second time without truncation, only to detect (and warn about) dropped text.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            # Only pass an attention mask when the text encoder config asks for one.
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        # Target dtype for the embeddings: text encoder first, then UNet, else leave the dtype unchanged.
        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        # `prompt_embeds` is unpacked as a 3-D tensor here: (batch, sequence length, hidden size).
        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                # no negative prompt: use one empty string per batch entry
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad/truncate the negative prompt to the same sequence length as the positive embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            # NOTE(review): `clip_skip` is not applied on this negative-prompt path — confirm this is intended.
            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Schedulers do not share a single `step` signature, so only the arguments the current scheduler's `step`
    actually declares are forwarded.

    Args:
        generator: Value forwarded as `generator` when `step` has a `generator` parameter.
        eta: Value forwarded as `eta` when `step` has an `eta` parameter. eta (η) corresponds to η in the DDIM
            paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        `dict`: Contains `"eta"` and/or `"generator"` (in that order) for each name accepted by `step`.
    """
    accepted = inspect.signature(self.scheduler.step).parameters
    optional_args = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_args if name in accepted}
) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): shape = ( batch_size, num_channels_latents, num_frames, height // self.vae_scale_factor, width // self.vae_scale_factor, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_frames: int = 16,
    num_inference_steps: int = 50,
    guidance_scale: float = 9.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "np",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated video.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated video.
        num_frames (`int`, *optional*, defaults to 16):
            The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
            amounts to 2 seconds of video.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 9.0):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
            `(batch_size, num_channel, num_frames, height, width)`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated video. Choose between `torch.Tensor` or `np.array`. Passing
            `"latent"` skips VAE decoding and returns the denoised latents as they are.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
            of a plain tuple.
        callback (`Callable`, *optional*):
            A function that calls every `callback_steps` steps during inference. The function is called with the
            following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. If not specified, the callback is called at
            every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its `"scale"` entry, if present, is also used as the LoRA scale of the text encoder.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
    Examples:

    Returns:
        [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
            returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # this pipeline always generates exactly one video per prompt
    num_images_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
        clip_skip=clip_skip,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        num_frames,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # reshape latents: fold the frame axis into the batch axis, because the scheduler is
            # stepped on (batch * frames, channel, height, width) tensors
            bsz, channel, frames, height, width = latents.shape
            latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
            noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # reshape latents back to (batch, channel, frames, height, width)
            latents = latents[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    # 8. Post processing
    if output_type == "latent":
        video = latents
    else:
        video_tensor = self.decode_latents(latents)
        video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

    # 9. Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return TextToVideoSDPipelineOutput(frames=video)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Extract latents from an encoder output object.

    Args:
        encoder_output: Object exposing either a `latent_dist` attribute or a `latents` attribute.
        generator: Passed to `latent_dist.sample` when `sample_mode` is `"sample"`.
        sample_mode: `"sample"` draws from `latent_dist`, `"argmax"` takes `latent_dist.mode()`. Any other value
            (or a missing `latent_dist`) falls back to the `latents` attribute.

    Returns:
        Whatever `latent_dist.sample(generator)`, `latent_dist.mode()` or `encoder_output.latents` yields.

    Raises:
        AttributeError: If none of the above sources is available for the given `sample_mode`.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()

    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
""" model_cpu_offload_seq = "text_encoder->unet->vae" def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet3DConditionModel, scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) prompt_embeds_tuple = self.encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=lora_scale, **kwargs, ) # concatenate for backwards comp prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt def encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 
attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into # the tuple to access the hidden states from the desired layer. prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
    """
    Decode video latents into a float32 video tensor with `self.vae`.

    Args:
        latents: 5-D tensor unpacked as `(batch_size, channels, num_frames, height, width)`.

    Returns:
        `torch.Tensor`: float32 tensor laid out as `(batch_size, decoded_channels, num_frames, ...)`, where the
        trailing dimensions are those of the tensor returned by `self.vae.decode(...).sample`.
    """
    inv_scale = 1 / self.vae.config.scaling_factor
    scaled = inv_scale * latents

    batch_size, channels, num_frames, height, width = scaled.shape
    # the VAE works on individual frames, so move frames next to the batch axis and merge the two
    per_frame = scaled.transpose(1, 2).reshape(batch_size * num_frames, channels, height, width)

    decoded = self.vae.decode(per_frame).sample

    # split batch and frames apart again and put the channel axis back in front of the frame axis
    frames_first = decoded.reshape(batch_size, num_frames, -1, *decoded.shape[2:])
    channels_first = frames_first.transpose(1, 2)

    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return channels_first.float()
) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." 
def get_timesteps(self, num_inference_steps, strength, device):
    """
    Select the trailing part of the scheduler's timesteps that corresponds to `strength`.

    Args:
        num_inference_steps: Number of steps the scheduler was configured with.
        strength: Fraction of the schedule to run; `int(num_inference_steps * strength)` steps are kept, capped at
            `num_inference_steps`.
        device: Not used by this method.

    Returns:
        `tuple`: The remaining entries of `self.scheduler.timesteps` and the number of inference steps left.
    """
    # Number of steps that will actually be executed, never more than the full schedule.
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    first_step = max(num_inference_steps - steps_to_run, 0)

    # The timestep list holds `order` entries per inference step.
    begin_index = first_step * self.scheduler.order
    remaining_timesteps = self.scheduler.timesteps[begin_index:]

    # Only some schedulers expose `set_begin_index`; tell those where the run starts.
    if hasattr(self.scheduler, "set_begin_index"):
        self.scheduler.set_begin_index(begin_index)

    return remaining_timesteps, num_inference_steps - first_step
) else: init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents latents = latents[None, :].reshape((bsz, frames, latents.shape[1]) + latents.shape[2:]).permute(0, 2, 1, 3, 4) return latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, video: Union[List[np.ndarray], torch.Tensor] = None, strength: float = 0.6, num_inference_steps: int = 50, guidance_scale: float = 15.0, negative_prompt: Optional[Union[str, List[str]]] = None, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. video (`List[np.ndarray]` or `torch.Tensor`): `video` frames or tensor representing a video batch to be used as the starting point for the process. Can also accept video latents as `image`, if passing latents directly, it will not be encoded again. strength (`float`, *optional*, defaults to 0.8): Indicates extent to transform the reference `video`. Must be between 0 and 1. `video` is used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. 
When `strength` is 1, added noise is maximum and the denoising process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 essentially ignores `video`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality videos at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in video generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. Latents should be of shape `(batch_size, num_channel, num_frames, height, width)`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"np"`): The output format of the generated video. Choose between `torch.Tensor` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. Examples: Returns: [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. """ # 0. Default height and width to unet num_images_per_prompt = 1 # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Preprocess video video = self.video_processor.preprocess_video(video) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables latents = self.prepare_latents(video, latent_timestep, batch_size, prompt_embeds.dtype, device, generator) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # reshape latents bsz, channel, frames, width, height = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # reshape latents back latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") # 9. Post processing if output_type == "latent": video = latents else: video_tensor = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. 
def rearrange_0(tensor, f):
    """Reshape a `(B * f, C, H, W)` tensor to `(B, f, C, H, W)` and permute it to `(B, C, f, H, W)`."""
    total_frames, channels, height, width = tensor.size()
    per_video = torch.reshape(tensor, (total_frames // f, f, channels, height, width))
    return torch.permute(per_video, (0, 2, 1, 3, 4))


def rearrange_1(tensor):
    """Permute a `(B, C, F, H, W)` tensor to `(B, F, C, H, W)` and flatten it to `(B * F, C, H, W)`."""
    batch, channels, frames, height, width = tensor.size()
    frames_first = torch.permute(tensor, (0, 2, 1, 3, 4))
    return torch.reshape(frames_first, (batch * frames, channels, height, width))


def rearrange_3(tensor, f):
    """Split the leading axis of a 3D `(B * f, D, C)` tensor into `(B, f, D, C)`."""
    total_frames, dim_1, dim_2 = tensor.size()
    return torch.reshape(tensor, (total_frames // f, f, dim_1, dim_2))


def rearrange_4(tensor):
    """Merge the two leading axes of a 4D `(B, F, D, C)` tensor into `(B * F, D, C)`."""
    batch, frames, dim_1, dim_2 = tensor.size()
    return torch.reshape(tensor, (batch * frames, dim_1, dim_2))


def _use_first_frame(tensor, batch_size):
    """
    Overwrite every frame of `tensor` with the first frame of the video it belongs to.

    Args:
        tensor: 3D tensor whose leading axis is treated as `batch_size * video_length` frames.
        batch_size: The number of videos the leading axis is split into.

    Returns:
        A tensor with the same shape as `tensor` in which each frame is a copy of its video's first frame.
    """
    # Clamp to 1 so that an input with fewer rows than `batch_size` does not produce a zero video length, which
    # would make `rearrange_3` divide by zero.
    video_length = max(1, tensor.size()[0] // batch_size)
    first_frame_index = [0] * video_length

    # batch and frames in the 1st and 2nd dims respectively, then pick frame 0 for every frame slot
    per_video = rearrange_3(tensor, video_length)
    per_video = per_video[:, first_frame_index]

    # back to the original shape
    return rearrange_4(per_video)


class CrossFrameAttnProcessor:
    """
    Cross frame attention processor. Each frame attends the first frame.

    Args:
        batch_size: The number that represents actual batch size, other than the frames.
            For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
            2, due to classifier-free guidance.
    """

    def __init__(self, batch_size=2):
        self.batch_size = batch_size

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """
        Run attention through `attn`; when no `encoder_hidden_states` are given, keys and values are taken from the
        first frame of each video.

        Args:
            attn: The attention module supplying the projections and the head reshaping helpers.
            hidden_states: 3D input tensor; its first two dimensions are read as batch and sequence length.
            encoder_hidden_states: Optional states to attend to. When `None`, `hidden_states` is used and the
                cross-frame substitution is applied.
            attention_mask: Optional mask, prepared with `attn.prepare_attention_mask`.

        Returns:
            The attended hidden states after `attn.to_out[0]` and `attn.to_out[1]`.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Cross Frame Attention
        if not is_cross_attention:
            key = _use_first_frame(key, self.batch_size)
            value = _use_first_frame(value, self.batch_size)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class CrossFrameAttnProcessor2_0:
    """
    Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0.

    Args:
        batch_size: The number that represents actual batch size, other than the frames.
            For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
            2, due to classifier-free guidance.
    """

    def __init__(self, batch_size=2):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.batch_size = batch_size

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """
        Run attention with `F.scaled_dot_product_attention`; when no `encoder_hidden_states` are given, keys and
        values are taken from the first frame of each video.

        Args:
            attn: The attention module supplying the projections and the number of heads (`attn.heads`).
            hidden_states: 3D input tensor.
            encoder_hidden_states: Optional states to attend to. When `None`, `hidden_states` is used and the
                cross-frame substitution is applied.
            attention_mask: Optional mask, prepared with `attn.prepare_attention_mask` and reshaped per head.

        Returns:
            The attended hidden states after `attn.to_out[0]` and `attn.to_out[1]`.
        """
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Cross Frame Attention
        if not is_cross_attention:
            key = _use_first_frame(key, self.batch_size)
            value = _use_first_frame(value, self.batch_size)

        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
""" images: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: Optional[List[bool]] def coords_grid(batch, ht, wd, device): # Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) coords = torch.stack(coords[::-1], dim=0).float() return coords[None].repeat(batch, 1, 1, 1) def warp_single_latent(latent, reference_flow): """ Warp latent of a single frame with given flow Args: latent: latent code of a single frame reference_flow: flow which to warp the latent with Returns: warped: warped latent """ _, _, H, W = reference_flow.size() _, _, h, w = latent.size() coords0 = coords_grid(1, H, W, device=latent.device).to(latent.dtype) coords_t0 = coords0 + reference_flow coords_t0[:, 0] /= W coords_t0[:, 1] /= H coords_t0 = coords_t0 * 2.0 - 1.0 coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear") coords_t0 = torch.permute(coords_t0, (0, 2, 3, 1)) warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection") return warped def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype): """ Create translation motion field Args: motion_field_strength_x: motion strength along x-axis motion_field_strength_y: motion strength along y-axis frame_ids: indexes of the frames the latents of which are being processed. 
def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype):
    """
    Create translation motion field

    Args:
        motion_field_strength_x: motion strength along x-axis
        motion_field_strength_y: motion strength along y-axis
        frame_ids: indexes of the frames the latents of which are being processed.
            This is needed when we perform chunk-by-chunk inference
        device: device
        dtype: dtype

    Returns:
        A tensor of shape `(len(frame_ids), 2, 512, 512)`. For every frame, channel 0 is filled with
        `motion_field_strength_x * frame_id` and channel 1 with `motion_field_strength_y * frame_id`.
    """
    flow = torch.zeros((len(frame_ids), 2, 512, 512), device=device, dtype=dtype)
    for row, frame_id in enumerate(frame_ids):
        # A translation is the same displacement at every location, so each plane is constant.
        flow[row, 0] = motion_field_strength_x * frame_id
        flow[row, 1] = motion_field_strength_y * frame_id
    return flow


def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_strength_y, frame_ids, latents):
    """
    Creates translation motion and warps the latents accordingly

    Args:
        motion_field_strength_x: motion strength along x-axis
        motion_field_strength_y: motion strength along y-axis
        frame_ids: indexes of the frames the latents of which are being processed.
            This is needed when we perform chunk-by-chunk inference
        latents: latent codes of frames

    Returns:
        warped_latents: warped latents
    """
    flow = create_motion_field(
        motion_field_strength_x=motion_field_strength_x,
        motion_field_strength_y=motion_field_strength_y,
        frame_ids=frame_ids,
        device=latents.device,
        dtype=latents.dtype,
    )
    # Work on a detached copy so the input latents are left untouched.
    warped = latents.clone().detach()
    for index, frame_latent in enumerate(latents):
        # `[None]` restores the leading batch axis for the single-frame warp.
        warped[index] = warp_single_latent(frame_latent[None], flow[index][None])
    return warped
tokenizer (`CLIPTokenizer`): A [`~transformers.CLIPTokenizer`] to tokenize text. unet ([`UNet2DConditionModel`]): A [`UNet3DConditionModel`] to denoise the encoded video latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`CLIPImageProcessor`]): A [`CLIPImageProcessor`] to extract features from generated images; used as inputs to the `safety_checker`. """ def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
def backward_loop(
    self,
    latents,
    timesteps,
    prompt_embeds,
    guidance_scale,
    callback,
    callback_steps,
    num_warmup_steps,
    extra_step_kwargs,
    cross_attention_kwargs=None,
):
    """
    Perform backward process given list of time steps.

    Args:
        latents:
            Latents at time timesteps[0].
        timesteps:
            Time steps along which to perform backward process.
        prompt_embeds:
            Pre-generated text embeddings, passed to the unet as `encoder_hidden_states`.
        guidance_scale:
            Classifier-free guidance weight. Guidance is applied only when `guidance_scale > 1`.
        callback (`Callable`, *optional*):
            Called as `callback(step: int, timestep: int, latents: torch.Tensor)` on progress updates whose loop
            index is a multiple of `callback_steps`.
        callback_steps (`int`):
            The frequency at which the `callback` function is called.
        num_warmup_steps:
            Number of warmup steps; they are excluded from the progress bar total.
        extra_step_kwargs:
            Extra keyword arguments forwarded to `self.scheduler.step`.
        cross_attention_kwargs:
            A kwargs dictionary that is passed along to the unet.

    Returns:
        latents:
            A detached copy of the latents after the last time step.
    """
    use_guidance = guidance_scale > 1.0
    progress_total = (len(timesteps) - num_warmup_steps) // self.scheduler.order

    with self.progress_bar(total=progress_total) as progress_bar:
        for index, t in enumerate(timesteps):
            # With guidance the batch is doubled: one half per branch of the prediction.
            if use_guidance:
                model_input = torch.cat([latents] * 2)
            else:
                model_input = latents
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            unet_output = self.unet(
                model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
            )
            noise_pred = unet_output.sample

            if use_guidance:
                unconditional, conditional = noise_pred.chunk(2)
                noise_pred = unconditional + guidance_scale * (conditional - unconditional)

            # x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # Progress (and the callback) advance on the last entry, or after warmup once per scheduler order.
            should_report = index == len(timesteps) - 1 or (
                (index + 1) > num_warmup_steps and (index + 1) % self.scheduler.order == 0
            )
            if not should_report:
                continue

            progress_bar.update()
            if callback is not None and index % callback_steps == 0:
                callback(index // getattr(self.scheduler, "order", 1), t, latents)

    return latents.clone().detach()
ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." 
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the initial latents, scaled by the scheduler's `init_noise_sigma`.

    Args:
        batch_size: Number of latents to prepare.
        num_channels_latents: Number of latent channels.
        height: Height that is divided by `self.vae_scale_factor` to get the latent height.
        width: Width that is divided by `self.vae_scale_factor` to get the latent width.
        dtype: Dtype of freshly sampled latents.
        device: Device of the latents; provided `latents` are moved to it.
        generator: Random generator, or a list with one generator per batch element.
        latents: Optional pre-generated latents. When given, no noise is sampled.

    Returns:
        The provided or freshly sampled latents multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = int(height) // self.vae_scale_factor
    latent_width = int(width) // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        latents = latents.to(device)
    else:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return latents * self.scheduler.init_noise_sigma
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in video generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"np"`): The output format of the generated video. Choose between `"latent"` and `"np"`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that calls every `callback_steps` steps during inference. The function is called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function is called. If not specified, the callback is called at every step. motion_field_strength_x (`float`, *optional*, defaults to 12): Strength of motion in generated video along x-axis. See the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1. motion_field_strength_y (`float`, *optional*, defaults to 12): Strength of motion in generated video along y-axis. See the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1. t0 (`int`, *optional*, defaults to 44): Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1. t1 (`int`, *optional*, defaults to 47): Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1. frame_ids (`List[int]`, *optional*): Indexes of the frames that are being generated. This is used when generating longer videos chunk-by-chunk. Returns: [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoPipelineOutput`]: The output contains a `ndarray` of the generated video, when `output_type` != `"latent"`, otherwise a latent code of generated videos and a list of `bool`s indicating whether the corresponding generated video contains "not-safe-for-work" (nsfw) content.. 
""" assert video_length > 0 if frame_ids is None: frame_ids = list(range(video_length)) assert len(frame_ids) == video_length assert num_videos_per_prompt == 1 # set the processor original_attn_proc = self.unet.attn_processors processor = ( CrossFrameAttnProcessor2_0(batch_size=2) if hasattr(F, "scaled_dot_product_attention") else CrossFrameAttnProcessor(batch_size=2) ) self.unet.set_attn_processor(processor) if isinstance(prompt, str): prompt = [prompt] if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] # Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) # Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # Encode input prompt prompt_embeds_tuple = self.encode_prompt( prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt ) prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) # Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # Prepare extra step kwargs. 
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order # Perform the first backward process up to time T_1 x_1_t1 = self.backward_loop( timesteps=timesteps[: -t1 - 1], prompt_embeds=prompt_embeds, latents=latents, guidance_scale=guidance_scale, callback=callback, callback_steps=callback_steps, extra_step_kwargs=extra_step_kwargs, num_warmup_steps=num_warmup_steps, ) scheduler_copy = copy.deepcopy(self.scheduler) # Perform the second backward process up to time T_0 x_1_t0 = self.backward_loop( timesteps=timesteps[-t1 - 1 : -t0 - 1], prompt_embeds=prompt_embeds, latents=x_1_t1, guidance_scale=guidance_scale, callback=callback, callback_steps=callback_steps, extra_step_kwargs=extra_step_kwargs, num_warmup_steps=0, ) # Propagate first frame latents at time T_0 to remaining frames x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1) # Add motion in latents at time T_0 x_2k_t0 = create_motion_field_and_warp_latents( motion_field_strength_x=motion_field_strength_x, motion_field_strength_y=motion_field_strength_y, latents=x_2k_t0, frame_ids=frame_ids[1:], ) # Perform forward process up to time T_1 x_2k_t1 = self.forward_loop( x_t0=x_2k_t0, t0=timesteps[-t0 - 1].item(), t1=timesteps[-t1 - 1].item(), generator=generator, ) # Perform backward process from time T_1 to 0 x_1k_t1 = torch.cat([x_1_t1, x_2k_t1]) b, l, d = prompt_embeds.size() prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d) self.scheduler = scheduler_copy x_1k_0 = self.backward_loop( timesteps=timesteps[-t1 - 1 :], prompt_embeds=prompt_embeds, latents=x_1k_t1, guidance_scale=guidance_scale, callback=callback, callback_steps=callback_steps, extra_step_kwargs=extra_step_kwargs, num_warmup_steps=0, ) latents = x_1k_0 # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") torch.cuda.empty_cache() if 
output_type == "latent": image = latents has_nsfw_concept = None else: image = self.decode_latents(latents) # Run safety checker image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # Offload all models self.maybe_free_model_hooks() # make sure to set the original attention processors back self.unet.set_attn_processor(original_attn_proc) if not return_dict: return (image, has_nsfw_concept) return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when `do_classifier_free_guidance` is true and
            `negative_prompt_embeds` is not given.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `Tuple[torch.Tensor, Optional[torch.Tensor]]`: `(prompt_embeds, negative_prompt_embeds)`. The second
        element is whatever was passed in (possibly `None`) when `do_classifier_free_guidance` is false.

    Raises:
        TypeError: if `negative_prompt` and `prompt` are of different types.
        ValueError: if a list `negative_prompt` does not have one entry per prompt.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # the batch size comes from `prompt` when it is given, otherwise from the precomputed embeddings
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # tokenize again without truncation, only to detect and report what was cut off
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # an attention mask is only forwarded when the text encoder config asks for one
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # pick the dtype to cast embeddings to: text encoder first, then unet, then keep as-is
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            # no negative prompt: use one empty string per prompt
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # pad/truncate the unconditional tokens to the sequence length of the positive embeddings
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
def rearrange_0(tensor, f):
    """
    Split the leading frame axis of a 4D tensor into groups of `f` frames and move channels ahead of frames.

    Args:
        tensor: tensor of shape `(F, C, H, W)`.
        f: number of frames per group.

    Returns:
        A view of shape `(F // f, C, f, H, W)`.
    """
    total_frames, channels, height, width = tensor.size()
    grouped = tensor.reshape(total_frames // f, f, channels, height, width)
    return grouped.permute(0, 2, 1, 3, 4)
class CrossFrameAttnProcessor:
    """
    Cross frame attention processor. Each frame attends the first frame.

    Args:
        batch_size: The number that represents actual batch size, other than the frames.
            For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to
            2, due to classifier-free guidance.
    """

    def __init__(self, batch_size=2):
        self.batch_size = batch_size

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """
        Run attention for one layer.

        For self-attention (`encoder_hidden_states is None`) the keys and values of every frame are replaced by
        those of the first frame of the same video. Cross-attention is computed as usual.

        Args:
            attn: attention module providing the projections and helper methods used below.
            hidden_states: query-side input; its first two dimensions are read as batch and sequence length.
            encoder_hidden_states: optional key/value-side input; `None` selects self-attention.
            attention_mask: optional mask, passed through `attn.prepare_attention_mask`.

        Returns:
            The attended hidden states after the output projection and dropout.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Cross Frame Attention
        if not is_cross_attention:
            # Clamp to at least one frame: when fewer than `self.batch_size` samples arrive the integer
            # division yields 0, which would make `rearrange_3` divide by zero. This matches the guard
            # already used by `CrossFrameAttnProcessor2_0`.
            video_length = max(1, key.size()[0] // self.batch_size)
            first_frame_index = [0] * video_length

            # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
            key = rearrange_3(key, video_length)
            key = key[:, first_frame_index]
            # rearrange values to have batch and frames in the 1st and 2nd dims respectively
            value = rearrange_3(value, video_length)
            value = value[:, first_frame_index]

            # rearrange back to original shape
            key = rearrange_4(key)
            value = rearrange_4(value)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) # Cross Frame Attention if not is_cross_attention: video_length = max(1, key.size()[0] // self.batch_size) first_frame_index = [0] * video_length # rearrange keys to have batch and frames in the 1st and 2nd dims respectively key = rearrange_3(key, video_length) key = key[:, first_frame_index] # rearrange values to have batch and frames in the 1st and 2nd dims respectively value = rearrange_3(value, video_length) value = value[:, first_frame_index] # rearrange back to original shape key = rearrange_4(key) value = rearrange_4(value) head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) return hidden_states @dataclass class TextToVideoSDXLPipelineOutput(BaseOutput): """ Output class for zero-shot text-to-video pipeline. Args: images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
""" images: Union[List[PIL.Image.Image], np.ndarray] # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.coords_grid def coords_grid(batch, ht, wd, device): # Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) coords = torch.stack(coords[::-1], dim=0).float() return coords[None].repeat(batch, 1, 1, 1) # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.warp_single_latent def warp_single_latent(latent, reference_flow): """ Warp latent of a single frame with given flow Args: latent: latent code of a single frame reference_flow: flow which to warp the latent with Returns: warped: warped latent """ _, _, H, W = reference_flow.size() _, _, h, w = latent.size() coords0 = coords_grid(1, H, W, device=latent.device).to(latent.dtype) coords_t0 = coords0 + reference_flow coords_t0[:, 0] /= W coords_t0[:, 1] /= H coords_t0 = coords_t0 * 2.0 - 1.0 coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear") coords_t0 = torch.permute(coords_t0, (0, 2, 3, 1)) warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection") return warped # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.create_motion_field def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype): """ Create translation motion field Args: motion_field_strength_x: motion strength along x-axis motion_field_strength_y: motion strength along y-axis frame_ids: indexes of the frames the latents of which are being processed. 
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    The guided prediction is scaled so that its per-sample standard deviation matches that of the text-conditioned
    prediction, then blended with the unscaled guided prediction using `guidance_rescale` as the mixing weight.
    """

    def _per_sample_std(sample):
        # standard deviation over every dimension except the leading (batch) one
        return sample.std(dim=list(range(1, sample.ndim)), keepdim=True)

    # rescale the results from guidance (fixes overexposure)
    std_ratio = _per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg)
    rescaled = noise_cfg * std_ratio
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    image_encoder: CLIPVisionModelWithProjection = None,
    feature_extractor: CLIPImageProcessor = None,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
):
    """
    Register the pipeline components and derive the helpers that depend on them.

    `add_watermarker=None` means "decide automatically": a watermarker is created only when
    `is_invisible_watermark_available()` reports that the optional dependency is installed.
    """
    super().__init__()

    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "text_encoder_2": text_encoder_2,
        "tokenizer": tokenizer,
        "tokenizer_2": tokenizer_2,
        "unet": unet,
        "scheduler": scheduler,
        "image_encoder": image_encoder,
        "feature_extractor": feature_extractor,
    }
    self.register_modules(**components)
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

    # spatial down-sampling factor of the VAE, derived from its number of blocks
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    self.default_sample_size = self.unet.config.sample_size

    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()

    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None
def upcast_vae(self):
    """
    Cast the VAE to float32, keeping selected sub-modules in the original dtype when that is possible.

    If the first attention layer of the decoder's mid block uses one of the torch 2.0 / xformers attention
    processors, `post_quant_conv`, `decoder.conv_in` and `decoder.mid_block` are cast back to the dtype the VAE
    had before the upcast.
    """
    vae = self.vae
    original_dtype = vae.dtype
    vae.to(dtype=torch.float32)

    memory_efficient_processors = (
        AttnProcessor2_0,
        XFormersAttnProcessor,
        FusedAttnProcessor2_0,
    )
    mid_block_processor = vae.decoder.mid_block.attentions[0].processor
    if not isinstance(mid_block_processor, memory_efficient_processors):
        return

    # if xformers or torch_2_0 is used attention block does not need
    # to be in float32 which can save lots of memory
    for module in (vae.post_quant_conv, vae.decoder.conv_in, vae.decoder.mid_block):
        module.to(original_dtype)
def check_inputs(
    self,
    prompt,
    prompt_2,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the user-facing arguments of the pipeline call.

    The checks run in a fixed order and the first violated one raises. Nothing is returned when all checks pass.

    Raises:
        ValueError: if `height` or `width` is not divisible by 8; if `callback_steps` is given but is not a
            positive `int`; if `callback_on_step_end_tensor_inputs` contains a name that is not listed in
            `self._callback_tensor_inputs`; if a text prompt and its pre-computed embeddings are both passed (or
            both missing, for the positive prompt); if a prompt is neither `str` nor `list`; if `prompt_embeds`
            and `negative_prompt_embeds` have different shapes; or if embeddings are passed without their pooled
            counterpart.
    """
    # Spatial size.
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # Legacy callback frequency.
    if callback_steps is not None:
        is_positive_int = isinstance(callback_steps, int) and callback_steps > 0
        if not is_positive_int:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    # Tensor names requested for the step-end callback must all be known to the pipeline.
    if callback_on_step_end_tensor_inputs is not None and any(
        name not in self._callback_tensor_inputs for name in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Positive prompt: exactly one of text / embeddings. Every guard raises, so order is all that matters.
    has_prompt = prompt is not None
    has_prompt_2 = prompt_2 is not None
    has_embeds = prompt_embeds is not None

    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if has_prompt_2 and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    if has_prompt_2 and not isinstance(prompt_2, (str, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    # Negative prompt: text and embeddings are mutually exclusive.
    has_negative_embeds = negative_prompt_embeds is not None

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    if negative_prompt_2 is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Directly passed embeddings must agree in shape.
    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    # Embeddings always travel together with their pooled version.
    if has_embeds and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )
    if has_negative_embeds and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )
If not defined, `prompt` is used in both text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # textual inversion: process multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = 
tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: # "2" because SDXL always indexes from the penultimate layer. prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_2 = ( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = [negative_prompt, negative_prompt_2] negative_prompt_embeds_list = [] for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( negative_prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) if self.text_encoder_2 is not None: prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] if self.text_encoder_2 is not None: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) else: negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = 
def forward_loop(self, x_t0, t0, t1, generator):
    """
    Apply the DDPM forward (noising) process to a latent, moving it from time t0 to time t1.

    The step is done in closed form: the product of `self.scheduler.alphas[t0:t1]` gives the cumulative signal
    retention `a`, and the result is `sqrt(a) * x_t0 + sqrt(1 - a) * noise`, with `noise` drawn by
    `randn_tensor` in the shape, dtype and device of `x_t0`.

    Args:
        x_t0:
            Latent code at time t0.
        t0:
            Timestep at t0; used as the start of the slice into `self.scheduler.alphas`.
        t1:
            Timestep at t1; used as the (exclusive) end of the slice into `self.scheduler.alphas`.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.

    Returns:
        x_t1:
            Forward process applied to x_t0 from time t0 to t1.
    """
    # Draw the noise first so the generator is consumed before any other work.
    noise = randn_tensor(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device)

    cumulative_alpha = torch.prod(self.scheduler.alphas[t0:t1])
    signal_scale = torch.sqrt(cumulative_alpha)
    noise_scale = torch.sqrt(1 - cumulative_alpha)

    return signal_scale * x_t0 + noise_scale * noise
Returns: latents: latents of backward process output at time timesteps[-1] """ do_classifier_free_guidance = guidance_scale > 1.0 num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order with self.progress_bar(total=num_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. 
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    prompt_2: Optional[Union[str, List[str]]] = None,
    video_length: Optional[int] = 8,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    denoising_end: Optional[float] = None,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    frame_ids: Optional[List[int]] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    latents: Optional[torch.Tensor] = None,
    motion_field_strength_x: float = 12,
    motion_field_strength_y: float = 12,
    output_type: Optional[str] = "tensor",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    original_size: Optional[Tuple[int, int]] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    target_size: Optional[Tuple[int, int]] = None,
    t0: int = 44,
    t1: int = 47,
):
    """
    Function invoked when calling the pipeline for generation.

    The first frame is denoised down to time T_0, copied to the remaining frames, warped with a motion field,
    noised back up to time T_1, and then all frames are denoised jointly with cross-frame attention. While the
    call runs, the UNet's attention processors are replaced by a cross-frame processor; the original processors
    are always restored before the call returns or raises.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
            is used in both text-encoders.
        video_length (`int`, *optional*, defaults to 8):
            The number of generated video frames. Must be greater than 0.
        height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        denoising_end (`float`, *optional*):
            When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
            completed before it is intentionally prematurely terminated. As a result, the returned sample will
            still retain a substantial amount of noise as determined by the discrete timesteps selected by the
            scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of
            a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
            Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt. Only `1` is accepted.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        frame_ids (`List[int]`, *optional*):
            Indexes of the frames that are being generated. This is used when generating longer videos
            chunk-by-chunk. Defaults to `list(range(video_length))`; its length must equal `video_length`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        motion_field_strength_x (`float`, *optional*, defaults to 12):
            Strength of motion in generated video along x-axis. See the
            [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
        motion_field_strength_y (`float`, *optional*, defaults to 12):
            Strength of motion in generated video along y-axis. See the
            [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
        output_type (`str`, *optional*, defaults to `"tensor"`):
            The output format of the generated frames. `"latent"` returns the final latents without running the
            VAE decoder, the watermarker or the image post-processor; any other value is forwarded to
            `self.image_processor.postprocess` after decoding.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a `TextToVideoSDXLPipelineOutput` instead of a plain tuple. This is honored
            for every `output_type`, including `"latent"`.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the UNet (via `backward_loop`) for its
            attention processors. Its `"scale"` entry, if present, is additionally used as the LoRA scale of the
            text encoders.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16.
            of that paper. Guidance rescale factor should fix overexposure when using zero terminal SNR. It is
            forwarded to every `backward_loop` call.
        original_size (`Tuple[int]`, *optional*):
            If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
            `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
            explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the
            position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by
            setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section
            2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        target_size (`Tuple[int]`, *optional*):
            For most cases, `target_size` should be set to the desired height and width of the generated image.
            If not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained
            in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        t0 (`int`, *optional*, defaults to 44):
            Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the
            [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
        t1 (`int`, *optional*, defaults to 47):
            Timestep t1. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
            [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.

    Returns:
        [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoSDXLPipelineOutput`] or
        `tuple`: [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoSDXLPipelineOutput`]
        if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element holds the
        generated images (or the latents when `output_type == "latent"`).

    Raises:
        AssertionError: if `video_length <= 0`, `len(frame_ids) != video_length` or `num_videos_per_prompt != 1`.
        ValueError: if `check_inputs` rejects the arguments.
    """
    # These preconditions stay as `assert`s so callers that catch `AssertionError` keep working.
    assert video_length > 0
    if frame_ids is None:
        frame_ids = list(range(video_length))
    assert len(frame_ids) == video_length

    assert num_videos_per_prompt == 1

    # Install the cross-frame attention processor for the duration of this call only.
    original_attn_proc = self.unet.attn_processors
    processor = (
        CrossFrameAttnProcessor2_0(batch_size=2)
        if hasattr(F, "scaled_dot_product_attention")
        else CrossFrameAttnProcessor(batch_size=2)
    )
    self.unet.set_attn_processor(processor)

    try:
        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]

        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

        # 2. Define call parameters
        batch_size = (
            1 if isinstance(prompt, str) else len(prompt) if isinstance(prompt, list) else prompt_embeds.shape[0]
        )
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_videos_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1)

        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # Arguments shared by all three denoising passes. `cross_attention_kwargs` and `guidance_rescale` are
        # forwarded here so that the user-supplied values actually reach the UNet / guidance step.
        shared_loop_kwargs = {
            "guidance_scale": guidance_scale,
            "callback": callback,
            "callback_steps": callback_steps,
            "extra_step_kwargs": extra_step_kwargs,
            "cross_attention_kwargs": cross_attention_kwargs,
            "guidance_rescale": guidance_rescale,
        }

        # Perform the first backward process up to time T_1
        x_1_t1 = self.backward_loop(
            timesteps=timesteps[: -t1 - 1],
            prompt_embeds=prompt_embeds,
            latents=latents,
            num_warmup_steps=num_warmup_steps,
            add_text_embeds=add_text_embeds,
            add_time_ids=add_time_ids,
            **shared_loop_kwargs,
        )

        # Snapshot the scheduler at T_1; it is put back before the final joint denoising pass.
        scheduler_copy = copy.deepcopy(self.scheduler)

        # Perform the second backward process up to time T_0
        x_1_t0 = self.backward_loop(
            timesteps=timesteps[-t1 - 1 : -t0 - 1],
            prompt_embeds=prompt_embeds,
            latents=x_1_t1,
            num_warmup_steps=0,
            add_text_embeds=add_text_embeds,
            add_time_ids=add_time_ids,
            **shared_loop_kwargs,
        )

        # Propagate first frame latents at time T_0 to remaining frames
        x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1)

        # Add motion in latents at time T_0
        x_2k_t0 = create_motion_field_and_warp_latents(
            motion_field_strength_x=motion_field_strength_x,
            motion_field_strength_y=motion_field_strength_y,
            latents=x_2k_t0,
            frame_ids=frame_ids[1:],
        )

        # Perform forward process up to time T_1
        x_2k_t1 = self.forward_loop(
            x_t0=x_2k_t0,
            t0=timesteps[-t0 - 1].to(torch.long),
            t1=timesteps[-t1 - 1].to(torch.long),
            generator=generator,
        )

        # Perform backward process from time T_1 to 0
        latents = torch.cat([x_1_t1, x_2k_t1])

        self.scheduler = scheduler_copy
        timesteps = timesteps[-t1 - 1 :]

        # Repeat every conditioning tensor once per frame.
        b, l, d = prompt_embeds.size()
        prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d)

        b, k = add_text_embeds.size()
        add_text_embeds = add_text_embeds[:, None].repeat(1, video_length, 1).reshape(b * video_length, k)

        b, k = add_time_ids.size()
        add_time_ids = add_time_ids[:, None].repeat(1, video_length, 1).reshape(b * video_length, k)

        # 7.1 Apply denoising_end
        if (
            denoising_end is not None
            and isinstance(denoising_end, float)
            and denoising_end > 0
            and denoising_end < 1
        ):
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        x_1k_0 = self.backward_loop(
            timesteps=timesteps,
            prompt_embeds=prompt_embeds,
            latents=latents,
            num_warmup_steps=0,
            add_text_embeds=add_text_embeds,
            add_time_ids=add_time_ids,
            **shared_loop_kwargs,
        )

        latents = x_1k_0

        if output_type == "latent":
            # No decoding requested: hand back the raw latents, but still fall through to the shared cleanup.
            image = latents
        else:
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)

            # apply watermark if available
            if self.watermark is not None:
                image = self.watermark.apply_watermark(image)

            image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()
    finally:
        # make sure to set the original attention processors back, even when an error occurred above
        self.unet.set_attn_processor(original_attn_proc)

    if not return_dict:
        return (image,)

    return TextToVideoSDXLPipelineOutput(images=image)
else: import sys sys.modules[__name__] = _LazyModule( __name__, globals()["__file__"], _import_structure, module_spec=__spec__, ) for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: src/diffusers/pipelines/unclip/pipeline_unclip.py ================================================ # Copyright 2024 Kakao Brain and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import List, Optional, Tuple, Union import torch from torch.nn import functional as F from transformers import CLIPTextModelWithProjection, CLIPTokenizer from transformers.models.clip.modeling_clip import CLIPTextModelOutput from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel from ...schedulers import UnCLIPScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .text_proj import UnCLIPTextProjModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name class UnCLIPPipeline(DiffusionPipeline): """ Pipeline for text-to-image generation using unCLIP. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: text_encoder ([`~transformers.CLIPTextModelWithProjection`]): Frozen text-encoder. 
tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. prior ([`PriorTransformer`]): The canonical unCLIP prior to approximate the image embedding from the text embedding. text_proj ([`UnCLIPTextProjModel`]): Utility class to prepare and combine the embeddings before they are passed to the decoder. decoder ([`UNet2DConditionModel`]): The decoder to invert the image embedding into an image. super_res_first ([`UNet2DModel`]): Super resolution UNet. Used in all but the last step of the super resolution diffusion process. super_res_last ([`UNet2DModel`]): Super resolution UNet. Used in the last step of the super resolution diffusion process. prior_scheduler ([`UnCLIPScheduler`]): Scheduler used in the prior denoising process (a modified [`DDPMScheduler`]). decoder_scheduler ([`UnCLIPScheduler`]): Scheduler used in the decoder denoising process (a modified [`DDPMScheduler`]). super_res_scheduler ([`UnCLIPScheduler`]): Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]). 
""" _exclude_from_cpu_offload = ["prior"] prior: PriorTransformer decoder: UNet2DConditionModel text_proj: UnCLIPTextProjModel text_encoder: CLIPTextModelWithProjection tokenizer: CLIPTokenizer super_res_first: UNet2DModel super_res_last: UNet2DModel prior_scheduler: UnCLIPScheduler decoder_scheduler: UnCLIPScheduler super_res_scheduler: UnCLIPScheduler model_cpu_offload_seq = "text_encoder->text_proj->decoder->super_res_first->super_res_last" def __init__( self, prior: PriorTransformer, decoder: UNet2DConditionModel, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, text_proj: UnCLIPTextProjModel, super_res_first: UNet2DModel, super_res_last: UNet2DModel, prior_scheduler: UnCLIPScheduler, decoder_scheduler: UnCLIPScheduler, super_res_scheduler: UnCLIPScheduler, ): super().__init__() self.register_modules( prior=prior, decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, text_proj=text_proj, super_res_first=super_res_first, super_res_last=super_res_last, prior_scheduler=prior_scheduler, decoder_scheduler=decoder_scheduler, super_res_scheduler=super_res_scheduler, ) def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma return latents def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, text_attention_mask: Optional[torch.Tensor] = None, ): if text_model_output is None: batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids 
= text_inputs.input_ids text_mask = text_inputs.attention_mask.bool().to(device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder(text_input_ids.to(device)) prompt_embeds = text_encoder_output.text_embeds text_enc_hid_states = text_encoder_output.last_hidden_state else: batch_size = text_model_output[0].shape[0] prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1] text_mask = text_attention_mask prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens = [""] * batch_size uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) 
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_enc_hid_states.shape[1] uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1) uncond_text_enc_hid_states = uncond_text_enc_hid_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) return prompt_embeds, text_enc_hid_states, text_mask @torch.no_grad() def __call__( self, prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, prior_num_inference_steps: int = 25, decoder_num_inference_steps: int = 25, super_res_num_inference_steps: int = 7, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prior_latents: Optional[torch.Tensor] = None, decoder_latents: Optional[torch.Tensor] = None, super_res_latents: Optional[torch.Tensor] = None, text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, text_attention_mask: Optional[torch.Tensor] = None, prior_guidance_scale: float = 4.0, decoder_guidance_scale: float = 8.0, output_type: Optional[str] = "pil", return_dict: bool = True, ): """ The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide image generation. This can only be left undefined if `text_model_output` and `text_attention_mask` is passed. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
prior_num_inference_steps (`int`, *optional*, defaults to 25): The number of denoising steps for the prior. More denoising steps usually lead to a higher quality image at the expense of slower inference. decoder_num_inference_steps (`int`, *optional*, defaults to 25): The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality image at the expense of slower inference. super_res_num_inference_steps (`int`, *optional*, defaults to 7): The number of denoising steps for super resolution. More denoising steps usually lead to a higher quality image at the expense of slower inference. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prior_latents (`torch.Tensor` of shape (batch size, embeddings dimension), *optional*): Pre-generated noisy latents to be used as inputs for the prior. decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*): Pre-generated noisy latents to be used as inputs for the decoder. super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*): Pre-generated noisy latents to be used as inputs for the super resolution UNets. prior_guidance_scale (`float`, *optional*, defaults to 4.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Classifier-free guidance is enabled when `prior_guidance_scale > 1` or `decoder_guidance_scale > 1`. decoder_guidance_scale (`float`, *optional*, defaults to 8.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Classifier-free guidance is enabled when `prior_guidance_scale > 1` or `decoder_guidance_scale > 1`. text_model_output (`CLIPTextModelOutput`, *optional*): Pre-defined [`CLIPTextModel`] outputs that can be derived from the text encoder.
Pre-defined text outputs can be passed for tasks like text embedding interpolations. Make sure to also pass `text_attention_mask` in this case. `prompt` can the be left `None`. text_attention_mask (`torch.Tensor`, *optional*): Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention masks are necessary when passing `text_model_output`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images. """ if prompt is not None: if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") else: batch_size = text_model_output[0].shape[0] device = self._execution_device batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask ) # prior self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) prior_timesteps_tensor = self.prior_scheduler.timesteps embedding_dim = self.prior.config.embedding_dim prior_latents = self.prepare_latents( (batch_size, embedding_dim), prompt_embeds.dtype, device, generator, prior_latents, self.prior_scheduler, ) for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = 
torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents predicted_image_embedding = self.prior( latent_model_input, timestep=t, proj_embedding=prompt_embeds, encoder_hidden_states=text_enc_hid_states, attention_mask=text_mask, ).predicted_image_embedding if do_classifier_free_guidance: predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( predicted_image_embedding_text - predicted_image_embedding_uncond ) if i + 1 == prior_timesteps_tensor.shape[0]: prev_timestep = None else: prev_timestep = prior_timesteps_tensor[i + 1] prior_latents = self.prior_scheduler.step( predicted_image_embedding, timestep=t, sample=prior_latents, generator=generator, prev_timestep=prev_timestep, ).prev_sample prior_latents = self.prior.post_process_latents(prior_latents) image_embeddings = prior_latents # done prior # decoder text_enc_hid_states, additive_clip_time_embeddings = self.text_proj( image_embeddings=image_embeddings, prompt_embeds=prompt_embeds, text_encoder_hidden_states=text_enc_hid_states, do_classifier_free_guidance=do_classifier_free_guidance, ) if device.type == "mps": # HACK: MPS: There is a panic when padding bool tensors, # so cast to int tensor for the pad and back to bool afterwards text_mask = text_mask.type(torch.int) decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) decoder_text_mask = decoder_text_mask.type(torch.bool) else: decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) decoder_timesteps_tensor = self.decoder_scheduler.timesteps num_channels_latents = self.decoder.config.in_channels height = self.decoder.config.sample_size width = self.decoder.config.sample_size decoder_latents = self.prepare_latents( (batch_size, 
num_channels_latents, height, width), text_enc_hid_states.dtype, device, generator, decoder_latents, self.decoder_scheduler, ) for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents noise_pred = self.decoder( sample=latent_model_input, timestep=t, encoder_hidden_states=text_enc_hid_states, class_labels=additive_clip_time_embeddings, attention_mask=decoder_text_mask, ).sample if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) if i + 1 == decoder_timesteps_tensor.shape[0]: prev_timestep = None else: prev_timestep = decoder_timesteps_tensor[i + 1] # compute the previous noisy sample x_t -> x_t-1 decoder_latents = self.decoder_scheduler.step( noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator ).prev_sample decoder_latents = decoder_latents.clamp(-1, 1) image_small = decoder_latents # done decoder # super res self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) super_res_timesteps_tensor = self.super_res_scheduler.timesteps channels = self.super_res_first.config.in_channels // 2 height = self.super_res_first.config.sample_size width = self.super_res_first.config.sample_size super_res_latents = self.prepare_latents( (batch_size, channels, height, width), image_small.dtype, device, generator, super_res_latents, self.super_res_scheduler, ) if device.type == "mps": # MPS does not support many interpolations image_upscaled = F.interpolate(image_small, size=[height, 
width]) else: interpolate_antialias = {} if "antialias" in inspect.signature(F.interpolate).parameters: interpolate_antialias["antialias"] = True image_upscaled = F.interpolate( image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias ) for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): # no classifier free guidance if i == super_res_timesteps_tensor.shape[0] - 1: unet = self.super_res_last else: unet = self.super_res_first latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) noise_pred = unet( sample=latent_model_input, timestep=t, ).sample if i + 1 == super_res_timesteps_tensor.shape[0]: prev_timestep = None else: prev_timestep = super_res_timesteps_tensor[i + 1] # compute the previous noisy sample x_t -> x_t-1 super_res_latents = self.super_res_scheduler.step( noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator ).prev_sample image = super_res_latents # done super res self.maybe_free_model_hooks() # post processing image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py ================================================ # Copyright 2024 Kakao Brain and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import List, Optional, Union import PIL.Image import torch from torch.nn import functional as F from transformers import ( CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ...models import UNet2DConditionModel, UNet2DModel from ...schedulers import UnCLIPScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .text_proj import UnCLIPTextProjModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name class UnCLIPImageVariationPipeline(DiffusionPipeline): """ Pipeline to generate image variations from an input image using UnCLIP. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: text_encoder ([`~transformers.CLIPTextModelWithProjection`]): Frozen text-encoder. tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. feature_extractor ([`~transformers.CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `image_encoder`. image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). text_proj ([`UnCLIPTextProjModel`]): Utility class to prepare and combine the embeddings before they are passed to the decoder. decoder ([`UNet2DConditionModel`]): The decoder to invert the image embedding into an image. super_res_first ([`UNet2DModel`]): Super resolution UNet. Used in all but the last step of the super resolution diffusion process. super_res_last ([`UNet2DModel`]): Super resolution UNet. 
Used in the last step of the super resolution diffusion process. decoder_scheduler ([`UnCLIPScheduler`]): Scheduler used in the decoder denoising process (a modified [`DDPMScheduler`]). super_res_scheduler ([`UnCLIPScheduler`]): Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]). """ decoder: UNet2DConditionModel text_proj: UnCLIPTextProjModel text_encoder: CLIPTextModelWithProjection tokenizer: CLIPTokenizer feature_extractor: CLIPImageProcessor image_encoder: CLIPVisionModelWithProjection super_res_first: UNet2DModel super_res_last: UNet2DModel decoder_scheduler: UnCLIPScheduler super_res_scheduler: UnCLIPScheduler model_cpu_offload_seq = "text_encoder->image_encoder->text_proj->decoder->super_res_first->super_res_last" def __init__( self, decoder: UNet2DConditionModel, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, text_proj: UnCLIPTextProjModel, feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection, super_res_first: UNet2DModel, super_res_last: UNet2DModel, decoder_scheduler: UnCLIPScheduler, super_res_scheduler: UnCLIPScheduler, ): super().__init__() self.register_modules( decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, text_proj=text_proj, feature_extractor=feature_extractor, image_encoder=image_encoder, super_res_first=super_res_first, super_res_last=super_res_last, decoder_scheduler=decoder_scheduler, super_res_scheduler=super_res_scheduler, ) # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma return latents def _encode_prompt(self, prompt, device, 
num_images_per_prompt, do_classifier_free_guidance): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt", ) text_input_ids = text_inputs.input_ids text_mask = text_inputs.attention_mask.bool().to(device) text_encoder_output = self.text_encoder(text_input_ids.to(device)) prompt_embeds = text_encoder_output.text_embeds text_encoder_hidden_states = text_encoder_output.last_hidden_state prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens = [""] * batch_size max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = 
uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) return prompt_embeds, text_encoder_hidden_states, text_mask def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None): dtype = next(self.image_encoder.parameters()).dtype if image_embeddings is None: if not isinstance(image, torch.Tensor): image = self.feature_extractor(images=image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) image_embeddings = self.image_encoder(image).image_embeds image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0) return image_embeddings @torch.no_grad() def __call__( self, image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor]] = None, num_images_per_prompt: int = 1, decoder_num_inference_steps: int = 25, super_res_num_inference_steps: int = 7, generator: Optional[torch.Generator] = None, decoder_latents: Optional[torch.Tensor] = None, super_res_latents: Optional[torch.Tensor] = None, image_embeddings: Optional[torch.Tensor] = None, decoder_guidance_scale: float = 8.0, output_type: Optional[str] = "pil", return_dict: bool = True, ): """ The call function to the pipeline for generation. Args: image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`): `Image` or tensor representing an image batch to be used as the starting point. 
If you provide a tensor, it needs to be compatible with the [`CLIPImageProcessor`] [configuration](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). Can be left as `None` only when `image_embeddings` are passed. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. decoder_num_inference_steps (`int`, *optional*, defaults to 25): The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality image at the expense of slower inference. super_res_num_inference_steps (`int`, *optional*, defaults to 7): The number of denoising steps for super resolution. More denoising steps usually lead to a higher quality image at the expense of slower inference. generator (`torch.Generator`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*): Pre-generated noisy latents to be used as inputs for the decoder. super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*): Pre-generated noisy latents to be used as inputs for the super resolution UNets. decoder_guidance_scale (`float`, *optional*, defaults to 8.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Classifier-free guidance is enabled when `decoder_guidance_scale > 1`. image_embeddings (`torch.Tensor`, *optional*): Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings can be passed for tasks like image interpolations. `image` can be left as `None`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images. """ if image is not None: if isinstance(image, PIL.Image.Image): batch_size = 1 elif isinstance(image, list): batch_size = len(image) else: batch_size = image.shape[0] else: batch_size = image_embeddings.shape[0] prompt = [""] * batch_size device = self._execution_device batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = decoder_guidance_scale > 1.0 prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance ) image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings) # decoder text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( image_embeddings=image_embeddings, prompt_embeds=prompt_embeds, text_encoder_hidden_states=text_encoder_hidden_states, do_classifier_free_guidance=do_classifier_free_guidance, ) if device.type == "mps": # HACK: MPS: There is a panic when padding bool tensors, # so cast to int tensor for the pad and back to bool afterwards text_mask = text_mask.type(torch.int) decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) decoder_text_mask = decoder_text_mask.type(torch.bool) else: decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) decoder_timesteps_tensor = self.decoder_scheduler.timesteps num_channels_latents = self.decoder.config.in_channels height = self.decoder.config.sample_size width = self.decoder.config.sample_size if decoder_latents is None: 
decoder_latents = self.prepare_latents( (batch_size, num_channels_latents, height, width), text_encoder_hidden_states.dtype, device, generator, decoder_latents, self.decoder_scheduler, ) for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents noise_pred = self.decoder( sample=latent_model_input, timestep=t, encoder_hidden_states=text_encoder_hidden_states, class_labels=additive_clip_time_embeddings, attention_mask=decoder_text_mask, ).sample if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) if i + 1 == decoder_timesteps_tensor.shape[0]: prev_timestep = None else: prev_timestep = decoder_timesteps_tensor[i + 1] # compute the previous noisy sample x_t -> x_t-1 decoder_latents = self.decoder_scheduler.step( noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator ).prev_sample decoder_latents = decoder_latents.clamp(-1, 1) image_small = decoder_latents # done decoder # super res self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) super_res_timesteps_tensor = self.super_res_scheduler.timesteps channels = self.super_res_first.config.in_channels // 2 height = self.super_res_first.config.sample_size width = self.super_res_first.config.sample_size if super_res_latents is None: super_res_latents = self.prepare_latents( (batch_size, channels, height, width), image_small.dtype, device, generator, super_res_latents, self.super_res_scheduler, ) if device.type == "mps": # MPS 
does not support many interpolations image_upscaled = F.interpolate(image_small, size=[height, width]) else: interpolate_antialias = {} if "antialias" in inspect.signature(F.interpolate).parameters: interpolate_antialias["antialias"] = True image_upscaled = F.interpolate( image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias ) for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): # no classifier free guidance if i == super_res_timesteps_tensor.shape[0] - 1: unet = self.super_res_last else: unet = self.super_res_first latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) noise_pred = unet( sample=latent_model_input, timestep=t, ).sample if i + 1 == super_res_timesteps_tensor.shape[0]: prev_timestep = None else: prev_timestep = super_res_timesteps_tensor[i + 1] # compute the previous noisy sample x_t -> x_t-1 super_res_latents = self.super_res_scheduler.step( noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator ).prev_sample image = super_res_latents # done super res self.maybe_free_model_hooks() # post processing image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: src/diffusers/pipelines/unclip/text_proj.py ================================================ # Copyright 2024 Kakao Brain and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
    """
    Helper model that merges CLIP image and text embeddings into the two conditioning inputs the decoder consumes:
    an additive time embedding and an extended sequence of encoder hidden states.

    For more details, see the original paper: https://arxiv.org/abs/2204.06125 section 2.1
    """

    @register_to_config
    def __init__(
        self,
        *,
        clip_extra_context_tokens: int = 4,
        clip_embeddings_dim: int = 768,
        time_embed_dim: int,
        cross_attention_dim,
    ):
        super().__init__()

        # Learned stand-in for the image embedding on the unconditional half of a guidance batch.
        self.learned_classifier_free_guidance_embeddings = nn.Parameter(torch.zeros(clip_embeddings_dim))

        # Projections whose sum forms the additive CLIP time embedding.
        self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim)
        self.clip_image_embeddings_project_to_time_embeddings = nn.Linear(clip_embeddings_dim, time_embed_dim)

        # Projections that build the encoder hidden states.
        self.clip_extra_context_tokens = clip_extra_context_tokens
        extra_context_width = self.clip_extra_context_tokens * cross_attention_dim
        self.clip_extra_context_tokens_proj = nn.Linear(clip_embeddings_dim, extra_context_width)
        self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim)
        self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance):
        """
        Build the decoder conditioning from CLIP embeddings.

        Args:
            image_embeddings: CLIP image embeddings. Under classifier-free guidance the learned unconditional
                embedding is prepended once per row, doubling the batch.
            prompt_embeds: Pooled text embeddings; their batch size must equal that of the (possibly doubled)
                image embeddings.
            text_encoder_hidden_states: Per-token text encoder outputs.
            do_classifier_free_guidance: Whether to prepend the unconditional image embeddings.

        Returns:
            A tuple `(text_encoder_hidden_states, additive_clip_time_embeddings)`: the projected and normalized
            text states with the extra context tokens concatenated in front along dim 1, and the summed time
            embedding projections.
        """
        if do_classifier_free_guidance:
            # One copy of the learned unconditional embedding per image embedding row.
            num_image_rows = image_embeddings.shape[0]
            uncond_image_embeddings = self.learned_classifier_free_guidance_embeddings.unsqueeze(0).expand(
                num_image_rows, -1
            )
            image_embeddings = torch.cat([uncond_image_embeddings, image_embeddings], dim=0)

        # The image embeddings batch size and the text embeddings batch size are equal
        assert image_embeddings.shape[0] == prompt_embeds.shape[0]

        batch_size = prompt_embeds.shape[0]

        # "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and
        # adding CLIP embeddings to the existing timestep embedding, ...
        prompt_time_embeds = self.embedding_proj(prompt_embeds)
        image_time_embeds = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings)
        additive_clip_time_embeddings = image_time_embeds + prompt_time_embeds

        # ... and by projecting CLIP embeddings into four
        # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
        extra_tokens = self.clip_extra_context_tokens_proj(image_embeddings)
        extra_tokens = extra_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens).permute(0, 2, 1)

        projected_text_states = self.text_encoder_hidden_states_norm(
            self.encoder_hidden_states_proj(text_encoder_hidden_states)
        )
        text_encoder_hidden_states = torch.cat([extra_tokens, projected_text_states], dim=1)

        return text_encoder_hidden_states, additive_clip_time_embeddings
# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    """
    Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
    generate text from the UniDiffuser image-text embedding.

    Parameters:
        prefix_length (`int`):
            Max number of prefix tokens that will be supplied to the model.
        prefix_inner_dim (`int`):
            The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
            CLIP text encoder.
        prefix_hidden_dim (`int`, *optional*):
            Hidden dim of the MLP if we encode the prefix. Required when `prefix_inner_dim != n_embd`.
        vocab_size (`int`, *optional*, defaults to 50257):
            Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
        n_positions (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_inner (`int`, *optional*, defaults to None):
            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        scale_attn_weights (`bool`, *optional*, defaults to `True`):
            Scale attention weights by dividing by sqrt(hidden_size)..
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
            Whether to additionally scale attention weights by `1 / layer_idx + 1`.
        reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
            Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
            dot-product/softmax to float() when training with mixed precision.
    """

    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]

    @register_to_config
    def __init__(
        self,
        prefix_length: int,
        prefix_inner_dim: int,
        prefix_hidden_dim: Optional[int] = None,
        vocab_size: int = 50257,  # Start of GPT2 config args
        n_positions: int = 1024,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_inner: Optional[int] = None,
        activation_function: str = "gelu_new",
        resid_pdrop: float = 0.1,
        embd_pdrop: float = 0.1,
        attn_pdrop: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        scale_attn_weights: bool = True,
        use_cache: bool = True,
        scale_attn_by_inverse_layer_idx: bool = False,
        reorder_and_upcast_attn: bool = False,
    ):
        super().__init__()

        self.prefix_length = prefix_length

        # Without a hidden projection the prefix is fed to GPT-2 unchanged, so its width must match `n_embd`.
        if prefix_inner_dim != n_embd and prefix_hidden_dim is None:
            raise ValueError(
                f"`prefix_hidden_dim` cannot be `None` when `prefix_inner_dim`: {prefix_inner_dim} and"
                f" `n_embd`: {n_embd} are not equal."
            )

        self.prefix_inner_dim = prefix_inner_dim
        self.prefix_hidden_dim = prefix_hidden_dim

        self.encode_prefix = (
            nn.Linear(self.prefix_inner_dim, self.prefix_hidden_dim)
            if self.prefix_hidden_dim is not None
            else nn.Identity()
        )
        self.decode_prefix = (
            nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()
        )

        gpt_config = GPT2Config(
            vocab_size=vocab_size,
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            n_inner=n_inner,
            activation_function=activation_function,
            resid_pdrop=resid_pdrop,
            embd_pdrop=embd_pdrop,
            attn_pdrop=attn_pdrop,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_range=initializer_range,
            scale_attn_weights=scale_attn_weights,
            use_cache=use_cache,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )
        self.transformer = GPT2LMHeadModel(gpt_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        prefix_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            input_ids (`torch.Tensor` of shape `(N, max_seq_len)`):
                Text tokens to use for inference.
            prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, 768)`):
                Prefix embedding to preprend to the embedded tokens.
            attention_mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len)`, *optional*):
                Attention mask over the concatenated prefix and text positions, forwarded to the GPT-2 model.
            labels (`torch.Tensor`, *optional*):
                Switches on language modeling labels. Only whether it is `None` matters: the labels passed to
                the transformer are rebuilt from `input_ids`, prefixed with `prefix_length` zero tokens.

        Returns:
            The GPT-2 language model output, or the tuple `(output, hidden)` with the encoded prefix when
            `prefix_hidden_dim` is set.
        """
        embedding_text = self.transformer.transformer.wte(input_ids)
        hidden = self.encode_prefix(prefix_embeds)
        prefix_embeds = self.decode_prefix(hidden)
        embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1)

        if labels is not None:
            # Pad the labels with dummy tokens so they line up with the prefix positions.
            dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device)
            labels = torch.cat((dummy_token, input_ids), dim=1)
        out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask)
        if self.prefix_hidden_dim is not None:
            return out, hidden
        else:
            return out

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return a `(batch_size, prefix_length)` int64 tensor of zeros on `device`."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def encode(self, prefix):
        """Apply the prefix encoder (a linear layer, or identity when `prefix_hidden_dim` is `None`)."""
        return self.encode_prefix(prefix)

    @torch.no_grad()
    def generate_captions(self, features, eos_token_id, device):
        """
        Generate caption token ids given text embedding features.

        Args:
            features (`torch.Tensor` of shape `(B, L, D)`):
                Text embedding features to generate captions from.
            eos_token_id (`int`):
                The token ID of the EOS token for the text decoder model.
            device:
                Device to perform text generation on.

        Returns:
            `Tuple(torch.Tensor, torch.Tensor)`: The best beam's token ids for every sample, shape
            `(B, longest_sequence)`, and the matching sequence lengths, shape `(B,)`. Samples whose beam search
            stopped earlier are right-padded with token id 0; use the lengths to trim them.
        """

        features = torch.split(features, 1, dim=0)
        generated_tokens = []
        generated_seq_lengths = []
        for feature in features:
            feature = self.decode_prefix(feature.to(device))  # back to the clip feature
            # Only support beam search for now
            output_tokens, seq_lengths = self.generate_beam(
                input_embeds=feature, device=device, eos_token_id=eos_token_id
            )
            generated_tokens.append(output_tokens[0])
            generated_seq_lengths.append(seq_lengths[0])
        # Beam search stops early once every beam has finished, so samples can end up with different token
        # counts and `torch.stack` would fail. Pad with 0, the id `generate_beam` itself appends to finished
        # beams; for equal-length samples this is identical to stacking.
        generated_tokens = torch.nn.utils.rnn.pad_sequence(generated_tokens, batch_first=True, padding_value=0)
        generated_seq_lengths = torch.stack(generated_seq_lengths)
        return generated_tokens, generated_seq_lengths

    @torch.no_grad()
    def generate_beam(
        self,
        input_ids=None,
        input_embeds=None,
        device=None,
        beam_size: int = 5,
        entry_length: int = 67,
        temperature: float = 1.0,
        eos_token_id: Optional[int] = None,
    ):
        """
        Generates text using the given tokenizer and text prompt or token embedding via beam search. This
        implementation is based on the beam search implementation from the [original UniDiffuser
        code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89).

        Args:
            eos_token_id (`int`, *optional*):
                The token ID of the EOS token for the text decoder model.
            input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
                Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds`
                must be supplied.
            input_embeds (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                An embedded representation to directly pass to the transformer as a prefix for beam search. One of
                `input_ids` and `input_embeds` must be supplied.
            device:
                The device to perform beam search on.
            beam_size (`int`, *optional*, defaults to `5`):
                The number of best states to store during beam search.
            entry_length (`int`, *optional*, defaults to `67`):
                The number of iterations to run beam search.
            temperature (`float`, *optional*, defaults to 1.0):
                The temperature to use when performing the softmax over logits from the decoding model. Values
                that are not positive are treated as 1.0.

        Returns:
            `Tuple(torch.Tensor, torch.Tensor)`: A tuple of tensors where the first element is a tensor of generated
            token sequences sorted by score in descending order, and the second element is the sequence lengths
            corresponding to those sequences.
        """
        # Generates text until stop_token is reached using beam search with the desired beam size.
        stop_token_index = eos_token_id
        tokens = None
        scores = None
        seq_lengths = torch.ones(beam_size, device=device, dtype=torch.int)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)

        if input_embeds is not None:
            generated = input_embeds
        else:
            generated = self.transformer.transformer.wte(input_ids)

        for _ in range(entry_length):
            outputs = self.transformer(inputs_embeds=generated)
            logits = outputs.logits
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            logits = logits.softmax(-1).log()

            if scores is None:
                # First step: seed the beams with the top `beam_size` tokens of the single prefix.
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                if tokens is None:
                    tokens = next_tokens
                else:
                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
                    tokens = torch.cat((tokens, next_tokens), dim=1)
            else:
                # Finished beams may only continue with token 0 at no cost, which keeps their score unchanged.
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                # Rank candidates by length-normalized score.
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                # Flat index -> (source beam, token id)
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]

            next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
            generated = torch.cat((generated, next_token_embed), dim=1)
            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
            if is_stopped.all():
                break

        scores = scores / seq_lengths
        order = scores.argsort(descending=True)
        # tokens tensors are already padded to max_seq_length
        output_texts = [tokens[i] for i in order]
        output_texts = torch.stack(output_texts, dim=0)
        seq_lengths = torch.tensor([seq_lengths[i] for i in order], dtype=seq_lengths.dtype)
        return output_texts, seq_lengths
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill `tensor` in place with truncated-normal samples and return it; runs under `torch.no_grad()`."""
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    def norm_cdf(x):
        # Cumulative distribution function of the standard normal.
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    mean_below_range = mean < a - 2 * std
    if mean_below_range or (mean > b + 2 * std):
        logger.warning(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect."
        )

    with torch.no_grad():
        # Sample uniformly between the CDF values of the two cut-offs, then push the samples through the
        # inverse normal CDF.
        lower_cdf = norm_cdf((a - mean) / std)
        upper_cdf = norm_cdf((b - mean) / std)

        # Uniform on [2l-1, 2u-1], the domain erfinv expects for the interval [l, u].
        tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)

        # erfinv yields a truncated standard normal (up to a sqrt(2) factor) ...
        tensor.erfinv_()

        # ... which is rescaled and shifted to the requested std and mean.
        tensor.mul_(std * math.sqrt(2.0)).add_(mean)

        # Final clamp so every value lies inside [a, b].
        tensor.clamp_(min=a, max=b)

    return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
    r"""Fill `tensor` in place with samples from a truncated normal distribution and return it.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to :math:`[a, b]`. The sampling
    scheme works best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


class PatchEmbed(nn.Module):
    """2D image to patch embedding: a strided convolution, optionally followed by flattening, a LayerNorm and
    fixed sin-cos position embeddings."""

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        use_pos_embed=True,
    ):
        super().__init__()

        patches_per_column = height // patch_size
        patches_per_row = width // patch_size
        num_patches = patches_per_column * patches_per_row

        self.flatten = flatten
        self.layer_norm = layer_norm

        # Non-overlapping patches: kernel and stride are both `patch_size`.
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) if layer_norm else None

        self.use_pos_embed = use_pos_embed
        if self.use_pos_embed:
            # The grid side is taken as the square root of the patch count.
            grid_size = int(num_patches**0.5)
            pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size)
            # Non-persistent: the buffer is not written to the state dict.
            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        """Embed `latent` into patch tokens, adding the position embeddings when they are enabled."""
        patches = self.proj(latent)
        if self.flatten:
            patches = patches.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            patches = self.norm(patches)
        if not self.use_pos_embed:
            return patches
        return patches + self.pos_embed


class SkipBlock(nn.Module):
    """Merge a skip connection into the main stream: concatenate on the last dim, project back to `dim`,
    then LayerNorm."""

    def __init__(self, dim: int):
        super().__init__()

        self.skip_linear = nn.Linear(2 * dim, dim)

        # Use torch.nn.LayerNorm for now, following the original code
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, skip):
        merged = torch.cat([x, skip], dim=-1)
        return self.norm(self.skip_linear(merged))
# Modified to support both pre-LayerNorm and post-LayerNorm configurations
# Don't support AdaLayerNormZero for now
# Modified from diffusers.models.attention.BasicTransformerBlock
class UTransformerBlock(nn.Module):
    r"""
    Transformer block derived from BasicTransformerBlock that can place its normalization layers either before
    (pre-LayerNorm) or after (post-LayerNorm) each attention / feed-forward sub-layer.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        num_embeds_ada_norm (:obj: `int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:obj: `bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the query and key to float32 when performing the attention calculation.
        norm_elementwise_affine (`bool`, *optional*):
            Whether to use learnable per-element affine parameters during layer normalization.
        norm_type (`str`, defaults to `"layer_norm"`):
            The layer norm implementation to use.
        pre_layer_norm (`bool`, *optional*):
            Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
            as opposed to after ("post-LayerNorm"). Note that `BasicTransformerBlock` uses pre-LayerNorm, e.g.
            `pre_layer_norm = True`.
        final_dropout (`bool`, *optional*):
            Whether to use a final Dropout layer after the feedforward network.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        pre_layer_norm: bool = True,
        final_dropout: bool = False,
    ):
        super().__init__()

        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.pre_layer_norm = pre_layer_norm

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        has_second_attention = cross_attention_dim is not None or double_self_attention

        # Arguments common to both attention layers.
        shared_attention_kwargs = {
            "query_dim": dim,
            "heads": num_attention_heads,
            "dim_head": attention_head_dim,
            "dropout": dropout,
            "bias": attention_bias,
            "upcast_attention": upcast_attention,
        }

        def build_attention_norm():
            # Norm used around the attention layers: adaptive when enabled, plain LayerNorm otherwise.
            if self.use_ada_layer_norm:
                return AdaLayerNorm(dim, num_embeds_ada_norm)
            return nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        # 1. Self-Attn
        self.attn1 = Attention(
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            **shared_attention_kwargs,
        )

        # 2. Cross-Attn
        if has_second_attention:
            # is self-attn if encoder_hidden_states is none
            self.attn2 = Attention(
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                **shared_attention_kwargs,
            )
        else:
            self.attn2 = None

        self.norm1 = build_attention_norm()

        # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
        # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
        # the second cross attention block.
        self.norm2 = build_attention_norm() if has_second_attention else None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

    def _apply_attention_norm(self, norm, tensor, timestep):
        """Call `norm` on `tensor`, also passing `timestep` when adaptive layer norm is in use."""
        if self.use_ada_layer_norm:
            return norm(tensor, timestep)
        return norm(tensor)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
        cross_attention_kwargs=None,
        class_labels=None,
    ):
        """
        Run the first attention layer, the optional second attention layer and the feed-forward layer, each
        followed by a residual addition. `class_labels` is accepted but not used by this block.
        """
        pre_norm = self.pre_layer_norm
        if cross_attention_kwargs is None:
            cross_attention_kwargs = {}

        # 1. Self-Attention
        attn_input = self._apply_attention_norm(self.norm1, hidden_states, timestep) if pre_norm else hidden_states
        attn_output = self.attn1(
            attn_input,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if not pre_norm:
            # Post-LayerNorm: normalize the sub-layer output before the residual addition.
            attn_output = self._apply_attention_norm(self.norm1, attn_output, timestep)
        hidden_states = attn_output + hidden_states

        if self.attn2 is not None:
            attn_input = (
                self._apply_attention_norm(self.norm2, hidden_states, timestep) if pre_norm else hidden_states
            )
            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
            # prepare attention mask here

            # 2. Cross-Attention
            attn_output = self.attn2(
                attn_input,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            if not pre_norm:
                attn_output = self._apply_attention_norm(self.norm2, attn_output, timestep)
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward (norm3 is always a plain LayerNorm)
        ff_input = self.norm3(hidden_states) if pre_norm else hidden_states
        ff_output = self.ff(ff_input)
        if not pre_norm:
            ff_output = self.norm3(ff_output)

        return ff_output + hidden_states
This matches the transformer block in the [original UniDiffuser implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104). Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (:obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. attention_bias (:obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. only_cross_attention (`bool`, *optional*): Whether to use only cross-attention layers. In this case two cross attention layers are used. double_self_attention (`bool`, *optional*): Whether to use two self-attention layers. In this case no cross attention layers are used. upcast_attention (`bool`, *optional*): Whether to upcast the query and key to float() when performing the attention calculation. norm_elementwise_affine (`bool`, *optional*): Whether to use learnable per-element affine parameters during layer normalization. norm_type (`str`, defaults to `"layer_norm"`): The layer norm implementation to use. pre_layer_norm (`bool`, *optional*): Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm (`pre_layer_norm = False`). final_dropout (`bool`, *optional*): Whether to use a final Dropout layer after the feedforward network. 
""" def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", pre_layer_norm: bool = False, final_dropout: bool = True, ): super().__init__() self.only_cross_attention = only_cross_attention self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" self.pre_layer_norm = pre_layer_norm if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: raise ValueError( f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." ) # 1. Self-Attn self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, ) # 2. Cross-Attn if cross_attention_dim is not None or double_self_attention: self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim if not double_self_attention else None, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) # is self-attn if encoder_hidden_states is none else: self.attn2 = None if self.use_ada_layer_norm: self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) if cross_attention_dim is not None or double_self_attention: # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. # I.e. 
the number of returned modulation chunks from AdaLayerZero would not make sense if returned during # the second cross attention block. self.norm2 = ( AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) ) else: self.norm2 = None # 3. Feed-forward self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) def forward( self, hidden_states, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, timestep=None, cross_attention_kwargs=None, class_labels=None, ): # Following the diffusers transformer block implementation, put the LayerNorm on the # residual backbone # Pre-LayerNorm if self.pre_layer_norm: if self.use_ada_layer_norm: hidden_states = self.norm1(hidden_states, timestep) else: hidden_states = self.norm1(hidden_states) # 1. Self-Attention cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} attn_output = self.attn1( hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, ) hidden_states = attn_output + hidden_states # Following the diffusers transformer block implementation, put the LayerNorm on the # residual backbone # Post-LayerNorm if not self.pre_layer_norm: if self.use_ada_layer_norm: hidden_states = self.norm1(hidden_states, timestep) else: hidden_states = self.norm1(hidden_states) if self.attn2 is not None: # Pre-LayerNorm if self.pre_layer_norm: hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly # prepare attention mask here # 2. 
Cross-Attention attn_output = self.attn2( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) hidden_states = attn_output + hidden_states # Post-LayerNorm if not self.pre_layer_norm: hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) # 3. Feed-forward # Pre-LayerNorm if self.pre_layer_norm: hidden_states = self.norm3(hidden_states) ff_output = self.ff(hidden_states) hidden_states = ff_output + hidden_states # Post-LayerNorm if not self.pre_layer_norm: hidden_states = self.norm3(hidden_states) return hidden_states # Modified from diffusers.models.transformer_2d.Transformer2DModel # Modify the transformer block structure to be U-Net like following U-ViT # Only supports patch-style input and torch.nn.LayerNorm currently # https://github.com/baofff/U-ViT class UTransformer2DModel(ModelMixin, ConfigMixin): """ Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion, similar to a U-Net. Supports only continuous (actual embeddings) inputs, which are embedded via a [`PatchEmbed`] layer and then reshaped to (b, t, d). Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): Pass if the input is continuous. The number of channels in the input. out_channels (`int`, *optional*): The number of output channels; if `None`, defaults to `in_channels`. num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups to use when performing Group Normalization. cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. attention_bias (`bool`, *optional*): Configure if the TransformerBlocks' attention should contain a bias parameter. sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See `ImagePositionalEmbeddings`. num_vector_embeds (`int`, *optional*): Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked latent pixel. patch_size (`int`, *optional*, defaults to 2): The patch size to use in the patch embedding. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. The number of diffusion steps used during training. Note that this is fixed at training time as it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for up to but not more than steps than `num_embeds_ada_norm`. use_linear_projection (int, *optional*): TODO: Not used only_cross_attention (`bool`, *optional*): Whether to use only cross-attention layers. In this case two cross attention layers are used in each transformer block. upcast_attention (`bool`, *optional*): Whether to upcast the query and key to float() when performing the attention calculation. norm_type (`str`, *optional*, defaults to `"layer_norm"`): The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. block_type (`str`, *optional*, defaults to `"unidiffuser"`): The transformer block implementation to use. 
If `"unidiffuser"`, has the LayerNorms on the residual backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard behavior in `diffusers`.) pre_layer_norm (`bool`, *optional*): Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm (`pre_layer_norm = False`). norm_elementwise_affine (`bool`, *optional*): Whether to use learnable per-element affine parameters during layer normalization. use_patch_pos_embed (`bool`, *optional*): Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). final_dropout (`bool`, *optional*): Whether to use a final Dropout layer after the feedforward network. """ @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, patch_size: Optional[int] = 2, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, norm_type: str = "layer_norm", block_type: str = "unidiffuser", pre_layer_norm: bool = False, norm_elementwise_affine: bool = True, use_patch_pos_embed=False, ff_final_dropout: bool = False, ): super().__init__() self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim # 1. 
Input # Only support patch input of shape (batch_size, num_channels, height, width) for now assert in_channels is not None and patch_size is not None, "Patch input requires in_channels and patch_size." assert sample_size is not None, "UTransformer2DModel over patched input must provide sample_size" # 2. Define input layers self.height = sample_size self.width = sample_size self.patch_size = patch_size self.pos_embed = PatchEmbed( height=sample_size, width=sample_size, patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, use_pos_embed=use_patch_pos_embed, ) # 3. Define transformers blocks # Modify this to have in_blocks ("downsample" blocks, even though we don't actually downsample), a mid_block, # and out_blocks ("upsample" blocks). Like a U-Net, there are skip connections from in_blocks to out_blocks in # a "U"-shaped fashion (e.g. first in_block to last out_block, etc.). # Quick hack to make the transformer block type configurable if block_type == "unidiffuser": block_cls = UniDiffuserBlock else: block_cls = UTransformerBlock self.transformer_in_blocks = nn.ModuleList( [ block_cls( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, norm_type=norm_type, pre_layer_norm=pre_layer_norm, norm_elementwise_affine=norm_elementwise_affine, final_dropout=ff_final_dropout, ) for d in range(num_layers // 2) ] ) self.transformer_mid_block = block_cls( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, norm_type=norm_type, pre_layer_norm=pre_layer_norm, 
norm_elementwise_affine=norm_elementwise_affine, final_dropout=ff_final_dropout, ) # For each skip connection, we use a SkipBlock (concatenation + Linear + LayerNorm) to process the inputs # before each transformer out_block. self.transformer_out_blocks = nn.ModuleList( [ nn.ModuleDict( { "skip": SkipBlock( inner_dim, ), "block": block_cls( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, norm_type=norm_type, pre_layer_norm=pre_layer_norm, norm_elementwise_affine=norm_elementwise_affine, final_dropout=ff_final_dropout, ), } ) for d in range(num_layers // 2) ] ) # 4. Define output layers self.out_channels = in_channels if out_channels is None else out_channels # Following the UniDiffuser U-ViT implementation, we process the transformer output with # a LayerNorm layer with per-element affine params self.norm_out = nn.LayerNorm(inner_dim) def forward( self, hidden_states, encoder_hidden_states=None, timestep=None, class_labels=None, cross_attention_kwargs=None, return_dict: bool = True, hidden_states_is_embedding: bool = False, unpatchify: bool = True, ): """ Args: hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.long`, *optional*): Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. 
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels conditioning. cross_attention_kwargs (*optional*): Keyword arguments to supply to the cross attention layers, if used. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. hidden_states_is_embedding (`bool`, *optional*, defaults to `False`): Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the transformer blocks. unpatchify (`bool`, *optional*, defaults to `True`): Whether to unpatchify the transformer output. Returns: [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ # 0. Check inputs if not unpatchify and return_dict: raise ValueError( f"Cannot both define `unpatchify`: {unpatchify} and `return_dict`: {return_dict} since when" f" `unpatchify` is {unpatchify} the returned output is of shape (batch_size, seq_len, hidden_dim)" " rather than (batch_size, num_channels, height, width)." ) # 1. Input if not hidden_states_is_embedding: hidden_states = self.pos_embed(hidden_states) # 2. 
Blocks # In ("downsample") blocks skips = [] for in_block in self.transformer_in_blocks: hidden_states = in_block( hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, ) skips.append(hidden_states) # Mid block hidden_states = self.transformer_mid_block(hidden_states) # Out ("upsample") blocks for out_block in self.transformer_out_blocks: hidden_states = out_block["skip"](hidden_states, skips.pop()) hidden_states = out_block["block"]( hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, ) # 3. Output # Don't support AdaLayerNorm for now, so no conditioning/scale/shift logic hidden_states = self.norm_out(hidden_states) # hidden_states = self.proj_out(hidden_states) if unpatchify: # unpatchify height = width = int(hidden_states.shape[1] ** 0.5) hidden_states = hidden_states.reshape( shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) ) hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) output = hidden_states.reshape( shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) else: output = hidden_states if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) class UniDiffuserModel(ModelMixin, ConfigMixin): """ Transformer model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is a modification of [`UTransformer2DModel`] with input and output heads for the VAE-embedded latent image, the CLIP-embedded image, and the CLIP-embedded prompt (see paper for more details). Parameters: text_dim (`int`): The hidden dimension of the CLIP text model used to embed images. clip_img_dim (`int`): The hidden dimension of the CLIP vision model used to embed prompts. 
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): Pass if the input is continuous. The number of channels in the input. out_channels (`int`, *optional*): The number of output channels; if `None`, defaults to `in_channels`. num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups to use when performing Group Normalization. cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. attention_bias (`bool`, *optional*): Configure if the TransformerBlocks' attention should contain a bias parameter. sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See `ImagePositionalEmbeddings`. num_vector_embeds (`int`, *optional*): Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked latent pixel. patch_size (`int`, *optional*, defaults to 2): The patch size to use in the patch embedding. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. The number of diffusion steps used during training. Note that this is fixed at training time as it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for up to but not more than steps than `num_embeds_ada_norm`. 
use_linear_projection (int, *optional*): TODO: Not used only_cross_attention (`bool`, *optional*): Whether to use only cross-attention layers. In this case two cross attention layers are used in each transformer block. upcast_attention (`bool`, *optional*): Whether to upcast the query and key to float32 when performing the attention calculation. norm_type (`str`, *optional*, defaults to `"layer_norm"`): The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. block_type (`str`, *optional*, defaults to `"unidiffuser"`): The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard behavior in `diffusers`.) pre_layer_norm (`bool`, *optional*): Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm (`pre_layer_norm = False`). norm_elementwise_affine (`bool`, *optional*): Whether to use learnable per-element affine parameters during layer normalization. use_patch_pos_embed (`bool`, *optional*): Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). ff_final_dropout (`bool`, *optional*): Whether to use a final Dropout layer after the feedforward network. use_data_type_embedding (`bool`, *optional*): Whether to use a data type embedding. This is only relevant for UniDiffuser-v1 style models; UniDiffuser-v1 is continue-trained from UniDiffuser-v0 on non-publically-available data and accepts a `data_type` argument, which can either be `1` to use the weights trained on non-publically-available data or `0` otherwise. This argument is subsequently embedded by the data type embedding, if used. 
""" @register_to_config def __init__( self, text_dim: int = 768, clip_img_dim: int = 512, num_text_tokens: int = 77, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, patch_size: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, norm_type: str = "layer_norm", block_type: str = "unidiffuser", pre_layer_norm: bool = False, use_timestep_embedding=False, norm_elementwise_affine: bool = True, use_patch_pos_embed=False, ff_final_dropout: bool = True, use_data_type_embedding: bool = False, ): super().__init__() # 0. Handle dimensions self.inner_dim = num_attention_heads * attention_head_dim assert sample_size is not None, "UniDiffuserModel over patched input must provide sample_size" self.sample_size = sample_size self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels self.patch_size = patch_size # Assume image is square... self.num_patches = (self.sample_size // patch_size) * (self.sample_size // patch_size) # 1. Define input layers # 1.1 Input layers for text and image input # For now, only support patch input for VAE latent image input self.vae_img_in = PatchEmbed( height=sample_size, width=sample_size, patch_size=patch_size, in_channels=in_channels, embed_dim=self.inner_dim, use_pos_embed=use_patch_pos_embed, ) self.clip_img_in = nn.Linear(clip_img_dim, self.inner_dim) self.text_in = nn.Linear(text_dim, self.inner_dim) # 1.2. 
Timestep embeddings for t_img, t_text self.timestep_img_proj = Timesteps( self.inner_dim, flip_sin_to_cos=True, downscale_freq_shift=0, ) self.timestep_img_embed = ( TimestepEmbedding( self.inner_dim, 4 * self.inner_dim, out_dim=self.inner_dim, ) if use_timestep_embedding else nn.Identity() ) self.timestep_text_proj = Timesteps( self.inner_dim, flip_sin_to_cos=True, downscale_freq_shift=0, ) self.timestep_text_embed = ( TimestepEmbedding( self.inner_dim, 4 * self.inner_dim, out_dim=self.inner_dim, ) if use_timestep_embedding else nn.Identity() ) # 1.3. Positional embedding self.num_text_tokens = num_text_tokens self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.inner_dim)) self.pos_embed_drop = nn.Dropout(p=dropout) trunc_normal_(self.pos_embed, std=0.02) # 1.4. Handle data type token embeddings for UniDiffuser-V1, if necessary self.use_data_type_embedding = use_data_type_embedding if self.use_data_type_embedding: self.data_type_token_embedding = nn.Embedding(2, self.inner_dim) self.data_type_pos_embed_token = nn.Parameter(torch.zeros(1, 1, self.inner_dim)) # 2. Define transformer blocks self.transformer = UTransformer2DModel( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, in_channels=in_channels, out_channels=out_channels, num_layers=num_layers, dropout=dropout, norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attention_bias=attention_bias, sample_size=sample_size, num_vector_embeds=num_vector_embeds, patch_size=patch_size, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, norm_type=norm_type, block_type=block_type, pre_layer_norm=pre_layer_norm, norm_elementwise_affine=norm_elementwise_affine, use_patch_pos_embed=use_patch_pos_embed, ff_final_dropout=ff_final_dropout, ) # 3. 
Define output layers patch_dim = (patch_size**2) * out_channels self.vae_img_out = nn.Linear(self.inner_dim, patch_dim) self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim) self.text_out = nn.Linear(self.inner_dim, text_dim) @torch.jit.ignore def no_weight_decay(self): return {"pos_embed"} def forward( self, latent_image_embeds: torch.Tensor, image_embeds: torch.Tensor, prompt_embeds: torch.Tensor, timestep_img: Union[torch.Tensor, float, int], timestep_text: Union[torch.Tensor, float, int], data_type: Optional[Union[torch.Tensor, float, int]] = 1, encoder_hidden_states=None, cross_attention_kwargs=None, ): """ Args: latent_image_embeds (`torch.Tensor` of shape `(batch size, latent channels, height, width)`): Latent image representation from the VAE encoder. image_embeds (`torch.Tensor` of shape `(batch size, 1, clip_img_dim)`): CLIP-embedded image representation (unsqueezed in the first dimension). prompt_embeds (`torch.Tensor` of shape `(batch size, seq_len, text_dim)`): CLIP-embedded text representation. timestep_img (`torch.long` or `float` or `int`): Current denoising step for the image. timestep_text (`torch.long` or `float` or `int`): Current denoising step for the text. data_type: (`torch.int` or `float` or `int`, *optional*, defaults to `1`): Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data, or `0` otherwise. encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. cross_attention_kwargs (*optional*): Keyword arguments to supply to the cross attention layers, if used. Returns: `tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is tbe VAE image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text embedding. 
""" batch_size = latent_image_embeds.shape[0] # 1. Input # 1.1. Map inputs to shape (B, N, inner_dim) vae_hidden_states = self.vae_img_in(latent_image_embeds) clip_hidden_states = self.clip_img_in(image_embeds) text_hidden_states = self.text_in(prompt_embeds) num_text_tokens, num_img_tokens = text_hidden_states.size(1), vae_hidden_states.size(1) # 1.2. Encode image timesteps to single token (B, 1, inner_dim) if not torch.is_tensor(timestep_img): timestep_img = torch.tensor([timestep_img], dtype=torch.long, device=vae_hidden_states.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep_img = timestep_img * torch.ones(batch_size, dtype=timestep_img.dtype, device=timestep_img.device) timestep_img_token = self.timestep_img_proj(timestep_img) # t_img_token does not contain any weights and will always return f32 tensors # but time_embedding might be fp16, so we need to cast here. timestep_img_token = timestep_img_token.to(dtype=self.dtype) timestep_img_token = self.timestep_img_embed(timestep_img_token) timestep_img_token = timestep_img_token.unsqueeze(dim=1) # 1.3. Encode text timesteps to single token (B, 1, inner_dim) if not torch.is_tensor(timestep_text): timestep_text = torch.tensor([timestep_text], dtype=torch.long, device=vae_hidden_states.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep_text = timestep_text * torch.ones(batch_size, dtype=timestep_text.dtype, device=timestep_text.device) timestep_text_token = self.timestep_text_proj(timestep_text) # t_text_token does not contain any weights and will always return f32 tensors # but time_embedding might be fp16, so we need to cast here. timestep_text_token = timestep_text_token.to(dtype=self.dtype) timestep_text_token = self.timestep_text_embed(timestep_text_token) timestep_text_token = timestep_text_token.unsqueeze(dim=1) # 1.4. Concatenate all of the embeddings together. 
if self.use_data_type_embedding: assert data_type is not None, "data_type must be supplied if the model uses a data type embedding" if not torch.is_tensor(data_type): data_type = torch.tensor([data_type], dtype=torch.int, device=vae_hidden_states.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML data_type = data_type * torch.ones(batch_size, dtype=data_type.dtype, device=data_type.device) data_type_token = self.data_type_token_embedding(data_type).unsqueeze(dim=1) hidden_states = torch.cat( [ timestep_img_token, timestep_text_token, data_type_token, text_hidden_states, clip_hidden_states, vae_hidden_states, ], dim=1, ) else: hidden_states = torch.cat( [timestep_img_token, timestep_text_token, text_hidden_states, clip_hidden_states, vae_hidden_states], dim=1, ) # 1.5. Prepare the positional embeddings and add to hidden states # Note: I think img_vae should always have the proper shape, so there's no need to interpolate # the position embeddings. if self.use_data_type_embedding: pos_embed = torch.cat( [self.pos_embed[:, : 1 + 1, :], self.data_type_pos_embed_token, self.pos_embed[:, 1 + 1 :, :]], dim=1 ) else: pos_embed = self.pos_embed hidden_states = hidden_states + pos_embed hidden_states = self.pos_embed_drop(hidden_states) # 2. Blocks hidden_states = self.transformer( hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=None, class_labels=None, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, hidden_states_is_embedding=True, unpatchify=False, )[0] # 3. Output # Split out the predicted noise representation. 
if self.use_data_type_embedding: ( t_img_token_out, t_text_token_out, data_type_token_out, text_out, img_clip_out, img_vae_out, ) = hidden_states.split((1, 1, 1, num_text_tokens, 1, num_img_tokens), dim=1) else: t_img_token_out, t_text_token_out, text_out, img_clip_out, img_vae_out = hidden_states.split( (1, 1, num_text_tokens, 1, num_img_tokens), dim=1 ) img_vae_out = self.vae_img_out(img_vae_out) # unpatchify height = width = int(img_vae_out.shape[1] ** 0.5) img_vae_out = img_vae_out.reshape( shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) ) img_vae_out = torch.einsum("nhwpqc->nchpwq", img_vae_out) img_vae_out = img_vae_out.reshape( shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) img_clip_out = self.clip_img_out(img_clip_out) text_out = self.text_out(text_out) return img_vae_out, img_clip_out, text_out ================================================ FILE: src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py ================================================ import inspect from dataclasses import dataclass from typing import Callable, List, Optional, Union import numpy as np import PIL.Image import torch from transformers import ( CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, GPT2Tokenizer, ) from ...image_processor import VaeImageProcessor from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.outputs import BaseOutput from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .modeling_text_decoder import UniDiffuserTextDecoder from .modeling_uvit import UniDiffuserModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name 
# New BaseOutput child class for joint image-text output @dataclass class ImageTextPipelineOutput(BaseOutput): """ Output class for joint image-text pipelines. Args: images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, num_channels)`. text (`List[str]` or `List[List[str]]`) List of generated text strings of length `batch_size` or a list of list of strings whose outer list has length `batch_size`. """ images: Optional[Union[List[PIL.Image.Image], np.ndarray]] text: Optional[Union[List[str], List[List[str]]]] class UniDiffuserPipeline(DiffusionPipeline): r""" Pipeline for a bimodal image-text model which supports unconditional text and image generation, text-conditioned image generation, image-conditioned text generation, and joint image-text generation. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. This is part of the UniDiffuser image representation along with the CLIP vision encoding. text_encoder ([`CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). image_encoder ([`CLIPVisionModel`]): A [`~transformers.CLIPVisionModel`] to encode images as part of its image representation along with the VAE latent representation. image_processor ([`CLIPImageProcessor`]): [`~transformers.CLIPImageProcessor`] to preprocess an image before CLIP encoding it with `image_encoder`. clip_tokenizer ([`CLIPTokenizer`]): A [`~transformers.CLIPTokenizer`] to tokenize the prompt before encoding it with `text_encoder`. text_decoder ([`UniDiffuserTextDecoder`]): Frozen text decoder. 
This is a GPT-style model which is used to generate text from the UniDiffuser embedding. text_tokenizer ([`GPT2Tokenizer`]): A [`~transformers.GPT2Tokenizer`] to decode text for text generation; used along with the `text_decoder`. unet ([`UniDiffuserModel`]): A [U-ViT](https://github.com/baofff/U-ViT) model with UNNet-style skip connections between transformer layers to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image and/or text latents. The original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler. """ # TODO: support for moving submodules for components with enable_model_cpu_offload model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae->text_decoder" def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, image_encoder: CLIPVisionModelWithProjection, clip_image_processor: CLIPImageProcessor, clip_tokenizer: CLIPTokenizer, text_decoder: UniDiffuserTextDecoder, text_tokenizer: GPT2Tokenizer, unet: UniDiffuserModel, scheduler: KarrasDiffusionSchedulers, ): super().__init__() if text_encoder.config.hidden_size != text_decoder.prefix_inner_dim: raise ValueError( f"The text encoder hidden size and text decoder prefix inner dim must be the same, but" f" `text_encoder.config.hidden_size`: {text_encoder.config.hidden_size} and `text_decoder.prefix_inner_dim`: {text_decoder.prefix_inner_dim}" ) self.register_modules( vae=vae, text_encoder=text_encoder, image_encoder=image_encoder, clip_image_processor=clip_image_processor, clip_tokenizer=clip_tokenizer, text_decoder=text_decoder, text_tokenizer=text_tokenizer, unet=unet, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.num_channels_latents = vae.config.latent_channels self.text_encoder_seq_len = text_encoder.config.max_position_embeddings 
self.text_encoder_hidden_size = text_encoder.config.hidden_size self.image_encoder_projection_dim = image_encoder.config.projection_dim self.unet_resolution = unet.config.sample_size self.text_intermediate_dim = self.text_encoder_hidden_size if self.text_decoder.prefix_hidden_dim is not None: self.text_intermediate_dim = self.text_decoder.prefix_hidden_dim self.mode = None # TODO: handle safety checking? self.safety_checker = None # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def _infer_mode(self, prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents): r""" Infer the generation task ('mode') from the inputs to `__call__`. If the mode has been manually set, the set mode will be used. 
""" prompt_available = (prompt is not None) or (prompt_embeds is not None) image_available = image is not None input_available = prompt_available or image_available prompt_latents_available = prompt_latents is not None vae_latents_available = vae_latents is not None clip_latents_available = clip_latents is not None full_latents_available = latents is not None image_latents_available = vae_latents_available and clip_latents_available all_indv_latents_available = prompt_latents_available and image_latents_available if self.mode is not None: # Preferentially use the mode set by the user mode = self.mode elif prompt_available: mode = "text2img" elif image_available: mode = "img2text" else: # Neither prompt nor image supplied, infer based on availability of latents if full_latents_available or all_indv_latents_available: mode = "joint" elif prompt_latents_available: mode = "text" elif image_latents_available: mode = "img" else: # No inputs or latents available mode = "joint" # Give warnings for ambiguous cases if self.mode is None and prompt_available and image_available: logger.warning( f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually," f" defaulting to mode '{mode}'." ) if self.mode is None and not input_available: if vae_latents_available != clip_latents_available: # Exactly one of vae_latents and clip_latents is supplied logger.warning( f"You have supplied exactly one of `vae_latents` and `clip_latents`, whereas either both or none" f" are expected to be supplied. Defaulting to mode '{mode}'." ) elif not prompt_latents_available and not vae_latents_available and not clip_latents_available: # No inputs or latents supplied logger.warning( f"No inputs or latents have been supplied, and mode has not been manually set," f" defaulting to mode '{mode}'." ) return mode # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. 
def disable_vae_slicing(self):
    r"""
    Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
    computing decoding in one step.
    """
    # Pure delegation: the switch lives on the VAE, nothing is stored on the pipeline itself.
    self.vae.disable_slicing()

# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_tiling
def enable_vae_tiling(self):
    r"""
    Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
    compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
    processing larger images.
    """
    # Pure delegation: the switch lives on the VAE, nothing is stored on the pipeline itself.
    self.vae.enable_tiling()
""" self.vae.disable_tiling() # Functions to manually set the mode def set_text_mode(self): r"""Manually set the generation mode to unconditional ("marginal") text generation.""" self.mode = "text" def set_image_mode(self): r"""Manually set the generation mode to unconditional ("marginal") image generation.""" self.mode = "img" def set_text_to_image_mode(self): r"""Manually set the generation mode to text-conditioned image generation.""" self.mode = "text2img" def set_image_to_text_mode(self): r"""Manually set the generation mode to image-conditioned text generation.""" self.mode = "img2text" def set_joint_mode(self): r"""Manually set the generation mode to unconditional joint image-text generation.""" self.mode = "joint" def reset_mode(self): r"""Removes a manually set mode; after calling this, the pipeline will infer the mode from inputs.""" self.mode = None def _infer_batch_size( self, mode, prompt, prompt_embeds, image, num_images_per_prompt, num_prompts_per_image, latents, prompt_latents, vae_latents, clip_latents, ): r"""Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`.""" if num_images_per_prompt is None: num_images_per_prompt = 1 if num_prompts_per_image is None: num_prompts_per_image = 1 assert num_images_per_prompt > 0, "num_images_per_prompt must be a positive integer" assert num_prompts_per_image > 0, "num_prompts_per_image must be a positive integer" if mode in ["text2img"]: if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: # Either prompt or prompt_embeds must be present for text2img. batch_size = prompt_embeds.shape[0] multiplier = num_images_per_prompt elif mode in ["img2text"]: if isinstance(image, PIL.Image.Image): batch_size = 1 else: # Image must be available and type either PIL.Image.Image or torch.Tensor. # Not currently supporting something like image_embeds. 
batch_size = image.shape[0] multiplier = num_prompts_per_image elif mode in ["img"]: if vae_latents is not None: batch_size = vae_latents.shape[0] elif clip_latents is not None: batch_size = clip_latents.shape[0] else: batch_size = 1 multiplier = num_images_per_prompt elif mode in ["text"]: if prompt_latents is not None: batch_size = prompt_latents.shape[0] else: batch_size = 1 multiplier = num_prompts_per_image elif mode in ["joint"]: if latents is not None: batch_size = latents.shape[0] elif prompt_latents is not None: batch_size = prompt_latents.shape[0] elif vae_latents is not None: batch_size = vae_latents.shape[0] elif clip_latents is not None: batch_size = clip_latents.shape[0] else: batch_size = 1 if num_images_per_prompt == num_prompts_per_image: multiplier = num_images_per_prompt else: multiplier = min(num_images_per_prompt, num_prompts_per_image) logger.warning( f"You are using mode `{mode}` and `num_images_per_prompt`: {num_images_per_prompt} and" f" num_prompts_per_image: {num_prompts_per_image} are not equal. Using batch size equal to" f" `min(num_images_per_prompt, num_prompts_per_image) = {batch_size}." ) return batch_size, multiplier # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `Tuple`: `(prompt_embeds, negative_prompt_embeds)`. `prompt_embeds` is repeated `num_images_per_prompt`
        times per prompt. `negative_prompt_embeds` is only computed and repeated when
        `do_classifier_free_guidance` is true; otherwise the value passed in (possibly `None`) is returned
        untouched.

    Raises:
        TypeError: If guidance is requested and `negative_prompt` has a different type than `prompt`.
        ValueError: If guidance is requested and a list `negative_prompt` does not have one entry per prompt.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    # One of `prompt` / `prompt_embeds` has to be supplied: with both `None` the `.shape` access below fails.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.clip_tokenizer)

        text_inputs = self.clip_tokenizer(
            prompt,
            padding="max_length",
            max_length=self.clip_tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize a second time without truncation, only to detect and report text that was cut off.
        untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.clip_tokenizer.batch_decode(
                untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}"
            )

        # The attention mask is only forwarded when the text encoder config opts in.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    # Pick the dtype to cast to: text encoder first, then the unet, else keep the embeddings' own dtype.
    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            # No negative prompt: use one empty string per prompt.
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.clip_tokenizer)

        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.clip_tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    return prompt_embeds, negative_prompt_embeds
def encode_image_vae_latents(
    self,
    image,
    batch_size,
    num_prompts_per_image,
    dtype,
    device,
    do_classifier_free_guidance,
    generator=None,
):
    r"""
    Encode `image` with the VAE and sample scaled latents from the resulting distribution.

    The effective batch size is `batch_size * num_prompts_per_image`. When fewer latents than that are produced and
    the effective batch size is a multiple of their count, the latents are tiled (with a deprecation notice).

    NOTE(review): the type check admits PIL images and lists, yet `.to(device=..., dtype=...)` is invoked on
    `image` straight afterwards - presumably callers always pass an already preprocessed tensor; confirm.

    Args:
        image: Input image batch; must pass the type check and support `.to(device=..., dtype=...)`.
        batch_size (`int`): Number of inputs before multiplication.
        num_prompts_per_image (`int`): Multiplier applied to `batch_size`.
        dtype, device: Target dtype and device for `image`.
        do_classifier_free_guidance (`bool`): When true, the result is the latents twice followed by a zero tensor
            of the same shape, concatenated along dim 0.
        generator (*optional*): A single generator, or a list holding one generator per effective batch entry.

    Returns:
        The sampled latents multiplied by `self.vae.config.scaling_factor`.

    Raises:
        ValueError: On an unsupported `image` type, a generator list of the wrong length, or an effective batch
            size that is larger than, but not a multiple of, the number of encoded latents.
    """
    accepted_types = (torch.Tensor, PIL.Image.Image, list)
    if not isinstance(image, accepted_types):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)

    batch_size = batch_size * num_prompts_per_image
    per_sample_generators = isinstance(generator, list)
    if per_sample_generators and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if per_sample_generators:
        # Encode one slice at a time so that every sample draws from its own generator.
        samples = []
        for idx in range(batch_size):
            posterior = self.vae.encode(image[idx : idx + 1]).latent_dist
            samples.append(posterior.sample(generator=generator[idx]) * self.vae.config.scaling_factor)
        image_latents = torch.cat(samples, dim=0)
    else:
        posterior = self.vae.encode(image).latent_dist
        # Scale the sample by the VAE's scaling factor
        image_latents = posterior.sample(generator=generator) * self.vae.config.scaling_factor

    if batch_size > image_latents.shape[0]:
        if batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        # expand image_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        copies = batch_size // image_latents.shape[0]
        image_latents = torch.cat([image_latents] * copies, dim=0)
    else:
        image_latents = torch.cat([image_latents], dim=0)

    if not do_classifier_free_guidance:
        return image_latents

    # Guidance layout: the latents twice, then an all-zero "unconditional" copy.
    return torch.cat([image_latents, image_latents, torch.zeros_like(image_latents)], dim=0)
def prepare_text_latents(
    self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None
):
    r"""
    Prepare latents for the CLIP embedded prompt.

    Args:
        batch_size (`int`): Number of inputs before multiplication.
        num_images_per_prompt (`int`): Multiplier; `batch_size * num_images_per_prompt` samples are produced.
        seq_len (`int`), hidden_size (`int`): Trailing dimensions of the noise tensor.
        dtype, device: Target dtype and device.
        generator: A single generator, or a list of generators. A list should hold one generator per produced
            sample; a list of length `batch_size` is still accepted for backward compatibility.
        latents (*optional*): Existing latents, assumed to have shape (B, L, D). They are repeated
            `num_images_per_prompt` times along dim 0 instead of drawing new noise.

    Returns:
        The latents multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length matches neither the effective batch size nor
            `batch_size`.
    """
    effective_batch_size = batch_size * num_images_per_prompt
    shape = (effective_batch_size, seq_len, hidden_size)
    # The noise tensor has `effective_batch_size` rows, so that is the length a generator list must match.
    # Previously only `batch_size` was accepted, which rejected correctly sized lists whenever the multiplier
    # was greater than one.
    if isinstance(generator, list) and len(generator) not in (batch_size, effective_batch_size):
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {effective_batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        # latents is assumed to have shape (B, L, D)
        latents = latents.repeat(num_images_per_prompt, 1, 1)
        latents = latents.to(device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents

# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
# Rename prepare_latents -> prepare_image_vae_latents and add num_prompts_per_image argument.
def prepare_image_vae_latents(
    self,
    batch_size,
    num_prompts_per_image,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    r"""
    Prepare latents for the VAE-encoded image.

    Args:
        batch_size (`int`): Number of inputs before multiplication.
        num_prompts_per_image (`int`): Multiplier; `batch_size * num_prompts_per_image` samples are produced.
        num_channels_latents (`int`): Channel dimension of the noise tensor.
        height (`int`), width (`int`): Pixel sizes; each is floor-divided by `self.vae_scale_factor`.
        dtype, device: Target dtype and device.
        generator: A single generator, or a list of generators. A list should hold one generator per produced
            sample; a list of length `batch_size` is still accepted for backward compatibility.
        latents (*optional*): Existing latents, assumed to have shape (B, C, H, W). They are repeated
            `num_prompts_per_image` times along dim 0 instead of drawing new noise.

    Returns:
        The latents multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length matches neither the effective batch size nor
            `batch_size`.
    """
    effective_batch_size = batch_size * num_prompts_per_image
    shape = (
        effective_batch_size,
        num_channels_latents,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )
    # Validate against the number of rows actually drawn, see `shape` above.
    if isinstance(generator, list) and len(generator) not in (batch_size, effective_batch_size):
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {effective_batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        # latents is assumed to have shape (B, C, H, W)
        latents = latents.repeat(num_prompts_per_image, 1, 1, 1)
        latents = latents.to(device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents

def prepare_image_clip_latents(
    self, batch_size, num_prompts_per_image, clip_img_dim, dtype, device, generator, latents=None
):
    r"""
    Prepare latents for the CLIP embedded image.

    Args:
        batch_size (`int`): Number of inputs before multiplication.
        num_prompts_per_image (`int`): Multiplier; `batch_size * num_prompts_per_image` samples are produced.
        clip_img_dim (`int`): Last dimension of the noise tensor; the middle dimension is fixed to 1.
        dtype, device: Target dtype and device.
        generator: A single generator, or a list of generators. A list should hold one generator per produced
            sample; a list of length `batch_size` is still accepted for backward compatibility.
        latents (*optional*): Existing latents, assumed to have shape (B, L, D). They are repeated
            `num_prompts_per_image` times along dim 0 instead of drawing new noise.

    Returns:
        The latents multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length matches neither the effective batch size nor
            `batch_size`.
    """
    effective_batch_size = batch_size * num_prompts_per_image
    shape = (effective_batch_size, 1, clip_img_dim)
    # Validate against the number of rows actually drawn, see `shape` above.
    if isinstance(generator, list) and len(generator) not in (batch_size, effective_batch_size):
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {effective_batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        # latents is assumed to have shape (B, L, D)
        latents = latents.repeat(num_prompts_per_image, 1, 1)
        latents = latents.to(device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents
(B, C * H * W + clip_img_dim). """ img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1)) img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1)) return torch.concat([img_vae, img_clip], dim=-1) def _split_joint(self, x, height, width): r""" Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim + text_seq_len * text_dim] into (img_vae, img_clip, text) where img_vae is of shape (B, C, H, W), img_clip is of shape (B, 1, clip_img_dim), and text is of shape (B, text_seq_len, text_dim). """ batch_size = x.shape[0] latent_height = height // self.vae_scale_factor latent_width = width // self.vae_scale_factor img_vae_dim = self.num_channels_latents * latent_height * latent_width text_dim = self.text_encoder_seq_len * self.text_intermediate_dim img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_projection_dim, text_dim], dim=1) img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)) img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim)) text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_intermediate_dim)) return img_vae, img_clip, text def _combine_joint(self, img_vae, img_clip, text): r""" Combines a latent image img_vae of shape (B, C, H, W), a CLIP-embedded image img_clip of shape (B, L_img, clip_img_dim), and a text embedding text of shape (B, L_text, text_dim) into a single embedding x of shape (B, C * H * W + L_img * clip_img_dim + L_text * text_dim). """ img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1)) img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1)) text = torch.reshape(text, (text.shape[0], -1)) return torch.concat([img_vae, img_clip, text], dim=-1) def _get_noise_pred( self, mode, latents, t, prompt_embeds, img_vae, img_clip, max_timestep, data_type, guidance_scale, generator, device, height, width, ): r""" Gets the noise prediction using the `unet` and performs classifier-free guidance, if necessary. 
""" if mode == "joint": # Joint text-image generation img_vae_latents, img_clip_latents, text_latents = self._split_joint(latents, height, width) img_vae_out, img_clip_out, text_out = self.unet( img_vae_latents, img_clip_latents, text_latents, timestep_img=t, timestep_text=t, data_type=data_type ) x_out = self._combine_joint(img_vae_out, img_clip_out, text_out) if guidance_scale <= 1.0: return x_out # Classifier-free guidance img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) _, _, text_out_uncond = self.unet( img_vae_T, img_clip_T, text_latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type ) img_vae_out_uncond, img_clip_out_uncond, _ = self.unet( img_vae_latents, img_clip_latents, text_T, timestep_img=t, timestep_text=max_timestep, data_type=data_type, ) x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond) return guidance_scale * x_out + (1.0 - guidance_scale) * x_out_uncond elif mode == "text2img": # Text-conditioned image generation img_vae_latents, img_clip_latents = self._split(latents, height, width) img_vae_out, img_clip_out, text_out = self.unet( img_vae_latents, img_clip_latents, prompt_embeds, timestep_img=t, timestep_text=0, data_type=data_type ) img_out = self._combine(img_vae_out, img_clip_out) if guidance_scale <= 1.0: return img_out # Classifier-free guidance text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( img_vae_latents, img_clip_latents, text_T, timestep_img=t, timestep_text=max_timestep, data_type=data_type, ) img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond) return guidance_scale * img_out + 
(1.0 - guidance_scale) * img_out_uncond elif mode == "img2text": # Image-conditioned text generation img_vae_out, img_clip_out, text_out = self.unet( img_vae, img_clip, latents, timestep_img=0, timestep_text=t, data_type=data_type ) if guidance_scale <= 1.0: return text_out # Classifier-free guidance img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( img_vae_T, img_clip_T, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type ) return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond elif mode == "text": # Unconditional ("marginal") text generation (no CFG) img_vae_out, img_clip_out, text_out = self.unet( img_vae, img_clip, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type ) return text_out elif mode == "img": # Unconditional ("marginal") image generation (no CFG) img_vae_latents, img_clip_latents = self._split(latents, height, width) img_vae_out, img_clip_out, text_out = self.unet( img_vae_latents, img_clip_latents, prompt_embeds, timestep_img=t, timestep_text=max_timestep, data_type=data_type, ) img_out = self._combine(img_vae_out, img_clip_out) return img_out def check_latents_shape(self, latents_name, latents, expected_shape): latents_shape = latents.shape expected_num_dims = len(expected_shape) + 1 # expected dimensions plus the batch dimension expected_shape_str = ", ".join(str(dim) for dim in expected_shape) if len(latents_shape) != expected_num_dims: raise ValueError( f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape" f" {latents_shape} has {len(latents_shape)} dimensions." 
) for i in range(1, expected_num_dims): if latents_shape[i] != expected_shape[i - 1]: raise ValueError( f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape" f" {latents_shape} has {latents_shape[i]} != {expected_shape[i - 1]} at dimension {i}." ) def check_inputs( self, mode, prompt, image, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, latents=None, prompt_latents=None, vae_latents=None, clip_latents=None, ): # Check inputs before running the generative process. if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: raise ValueError( f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." ) if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if mode == "text2img": if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if mode == "img2text": if image is None: raise ValueError("`img2text` mode requires an image to be provided.") # Check provided latents latent_height = height // self.vae_scale_factor latent_width = width // self.vae_scale_factor full_latents_available = latents is not None prompt_latents_available = prompt_latents is not None vae_latents_available = vae_latents is not None clip_latents_available = clip_latents is not None if full_latents_available: individual_latents_available = ( prompt_latents is not None or vae_latents is not None or clip_latents is not None ) if individual_latents_available: logger.warning( "You have supplied both `latents` and at least one of `prompt_latents`, `vae_latents`, and" " `clip_latents`. The value of `latents` will override the value of any individually supplied latents." 
) # Check shape of full latents img_vae_dim = self.num_channels_latents * latent_height * latent_width text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size latents_dim = img_vae_dim + self.image_encoder_projection_dim + text_dim latents_expected_shape = (latents_dim,) self.check_latents_shape("latents", latents, latents_expected_shape) # Check individual latent shapes, if present if prompt_latents_available: prompt_latents_expected_shape = (self.text_encoder_seq_len, self.text_encoder_hidden_size) self.check_latents_shape("prompt_latents", prompt_latents, prompt_latents_expected_shape) if vae_latents_available: vae_latents_expected_shape = (self.num_channels_latents, latent_height, latent_width) self.check_latents_shape("vae_latents", vae_latents, vae_latents_expected_shape) if clip_latents_available: clip_latents_expected_shape = (1, self.image_encoder_projection_dim) self.check_latents_shape("clip_latents", clip_latents, clip_latents_expected_shape) if mode in ["text2img", "img"] and vae_latents_available and clip_latents_available: if vae_latents.shape[0] != clip_latents.shape[0]: raise ValueError( f"Both `vae_latents` and `clip_latents` are supplied, but their batch dimensions are not equal:" f" {vae_latents.shape[0]} != {clip_latents.shape[0]}." ) if mode == "joint" and prompt_latents_available and vae_latents_available and clip_latents_available: if prompt_latents.shape[0] != vae_latents.shape[0] or prompt_latents.shape[0] != clip_latents.shape[0]: raise ValueError( f"All of `prompt_latents`, `vae_latents`, and `clip_latents` are supplied, but their batch" f" dimensions are not equal: {prompt_latents.shape[0]} != {vae_latents.shape[0]}" f" != {clip_latents.shape[0]}." 
) @torch.no_grad() def __call__( self, prompt: Optional[Union[str, List[str]]] = None, image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None, height: Optional[int] = None, width: Optional[int] = None, data_type: Optional[int] = 1, num_inference_steps: int = 50, guidance_scale: float = 8.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, num_prompts_per_image: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_latents: Optional[torch.Tensor] = None, vae_latents: Optional[torch.Tensor] = None, clip_latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. Required for text-conditioned image generation (`text2img`) mode. image (`torch.Tensor` or `PIL.Image.Image`, *optional*): `Image` or tensor representing an image batch. Required for image-conditioned text generation (`img2text`) mode. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. data_type (`int`, *optional*, defaults to 1): The data type (either 0 or 1). Only used if you are loading a checkpoint which supports a data type embedding; this is added for compatibility with the [UniDiffuser-v1](https://huggingface.co/thu-ml/unidiffuser-v1) checkpoint. 
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 8.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`). Used
                in text-conditioned image generation (`text2img`) mode.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and
                `img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
                supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples are generated.
            num_prompts_per_image (`int`, *optional*, defaults to 1):
                The number of prompts to generate per image. Used in `img2text` (image-conditioned text generation) and
                `text` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
                supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples are generated.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for joint
                image-text generation. Can be used to tweak the same generation with different prompts. If not
                provided, a latents tensor is generated by sampling using the supplied random `generator`. This is
                assumed to be a full set of VAE, CLIP, and text latents; if supplied, it overrides the values of
                `prompt_latents`, `vae_latents`, and `clip_latents`.
            prompt_latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for text
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            vae_latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            clip_latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument. Used in text-conditioned
                image generation (`text2img`) mode.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
Used in text-conditioned image generation (`text2img`) mode.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`. With `"latent"`,
                the VAE decoding step is skipped and the image latents are handed to the image post-processor
                instead.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImageTextPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that is called every `callback_steps` steps during inference. The function is called with
                the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Returns:
            [`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.unidiffuser.ImageTextPipelineOutput`] is returned, otherwise a
                `tuple` is returned where the first element is a list with the generated images and the second element
                is a list of generated texts.
        """

        # 0. Default height and width to unet
        height = height or self.unet_resolution * self.vae_scale_factor
        width = width or self.unet_resolution * self.vae_scale_factor

        # 1. Check inputs
        # Recalculate mode for each call to the pipeline.
        mode = self._infer_mode(prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents)
        self.check_inputs(
            mode,
            prompt,
            image,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            latents,
            prompt_latents,
            vae_latents,
            clip_latents,
        )

        # 2. Define call parameters
        # `multiplier` is forwarded below as the per-prompt / per-image sample count.
        batch_size, multiplier = self._infer_batch_size(
            mode,
            prompt,
            prompt_embeds,
            image,
            num_images_per_prompt,
            num_prompts_per_image,
            latents,
            prompt_latents,
            vae_latents,
            clip_latents,
        )
        device = self._execution_device
        # NOTE(review): this tests `self.mode` rather than the `mode` inferred above - confirm that is intended for
        # calls where the mode is inferred from the inputs.
        reduce_text_emb_dim = self.text_intermediate_dim < self.text_encoder_hidden_size or self.mode != "text2img"
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        # Note that this differs from the formulation in the unidiffusers paper!
        do_classifier_free_guidance = guidance_scale > 1.0

        # check if scheduler is in sigmas space
        # scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")

        # 3. Encode input prompt, if available; otherwise prepare text latents
        if latents is not None:
            # Overwrite individual latents
            vae_latents, clip_latents, prompt_latents = self._split_joint(latents, height, width)

        if mode in ["text2img"]:
            # 3.1. Encode input prompt, if available
            assert prompt is not None or prompt_embeds is not None
            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
                prompt=prompt,
                device=device,
                num_images_per_prompt=multiplier,
                do_classifier_free_guidance=do_classifier_free_guidance,
                negative_prompt=negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
            )

            # if do_classifier_free_guidance:
            #     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            # 3.2. Prepare text latent variables, if input not available
            prompt_embeds = self.prepare_text_latents(
                batch_size=batch_size,
                num_images_per_prompt=multiplier,
                seq_len=self.text_encoder_seq_len,
                hidden_size=self.text_encoder_hidden_size,
                dtype=self.text_encoder.dtype,  # Should work with both full precision and mixed precision
                device=device,
                generator=generator,
                latents=prompt_latents,
            )

        # NOTE(review): indentation is not recoverable from this copy of the file; the reduction is laid out as
        # applying to both branches above - confirm against the upstream pipeline.
        if reduce_text_emb_dim:
            prompt_embeds = self.text_decoder.encode(prompt_embeds)

        # 4. Encode image, if available; otherwise prepare image latents
        if mode in ["img2text"]:
            # 4.1. Encode images, if available
            assert image is not None, "`img2text` requires a conditioning image"
            # Encode image using VAE
            image_vae = self.image_processor.preprocess(image)
            # From here on, `height` and `width` are those of the preprocessed conditioning image.
            height, width = image_vae.shape[-2:]
            image_vae_latents = self.encode_image_vae_latents(
                image=image_vae,
                batch_size=batch_size,
                num_prompts_per_image=multiplier,
                dtype=prompt_embeds.dtype,
                device=device,
                do_classifier_free_guidance=False,  # Copied from InstructPix2Pix, don't use their version of CFG
                generator=generator,
            )

            # Encode image using CLIP
            image_clip_latents = self.encode_image_clip_latents(
                image=image,
                batch_size=batch_size,
                num_prompts_per_image=multiplier,
                dtype=prompt_embeds.dtype,
                device=device,
                generator=generator,
            )
            # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size)
            image_clip_latents = image_clip_latents.unsqueeze(1)
        else:
            # 4.2. Prepare image latent variables, if input not available
            # Prepare image VAE latents in latent space
            image_vae_latents = self.prepare_image_vae_latents(
                batch_size=batch_size,
                num_prompts_per_image=multiplier,
                num_channels_latents=self.num_channels_latents,
                height=height,
                width=width,
                dtype=prompt_embeds.dtype,
                device=device,
                generator=generator,
                latents=vae_latents,
            )

            # Prepare image CLIP latents
            image_clip_latents = self.prepare_image_clip_latents(
                batch_size=batch_size,
                num_prompts_per_image=multiplier,
                clip_img_dim=self.image_encoder_projection_dim,
                dtype=prompt_embeds.dtype,
                device=device,
                generator=generator,
                latents=clip_latents,
            )

        # 5. Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        # max_timestep = timesteps[0]
        max_timestep = self.scheduler.config.num_train_timesteps

        # 6. Prepare latent variables
        # The sample that is denoised depends on the mode: joint embedding, image embedding, or text embedding.
        if mode == "joint":
            latents = self._combine_joint(image_vae_latents, image_clip_latents, prompt_embeds)
        elif mode in ["text2img", "img"]:
            latents = self._combine(image_vae_latents, image_clip_latents)
        elif mode in ["img2text", "text"]:
            latents = prompt_embeds

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        logger.debug(f"Scheduler extra step kwargs: {extra_step_kwargs}")

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # predict the noise residual
                # Also applies classifier-free guidance as described in the UniDiffuser paper
                noise_pred = self._get_noise_pred(
                    mode,
                    latents,
                    t,
                    prompt_embeds,
                    image_vae_latents,
                    image_clip_latents,
                    max_timestep,
                    data_type,
                    guidance_scale,
                    generator,
                    device,
                    height,
                    width,
                )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        # 9. Post-processing
        # `image` / `text` stay `None` for the modality that the current mode does not generate.
        image = None
        text = None
        if mode == "joint":
            image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width)

            if not output_type == "latent":
                # Map latent VAE image back to pixel space
                image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0]
            else:
                image = image_vae_latents

            text = self.decode_text_latents(text_latents, device)
        elif mode in ["text2img", "img"]:
            image_vae_latents, image_clip_latents = self._split(latents, height, width)

            if not output_type == "latent":
                # Map latent VAE image back to pixel space
                image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0]
            else:
                image = image_vae_latents
        elif mode in ["img2text", "text"]:
            text_latents = latents
            text = self.decode_text_latents(text_latents, device)

        self.maybe_free_model_hooks()

        # 10.
# Postprocess the image, if necessary
        if image is not None:
            # Every image in the batch is denormalized by the image processor.
            do_denormalize = [True] * image.shape[0]
            image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, text)

        return ImageTextPipelineOutput(images=image, text=text)


================================================
FILE: src/diffusers/pipelines/wuerstchen/__init__.py
================================================
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Placeholder objects registered on the module when torch/transformers are missing.
_dummy_objects = {}
# Maps each submodule name to the public names it provides; handed to `_LazyModule` below.
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modeling_paella_vq_model"] = ["PaellaVQModel"]
    _import_structure["modeling_wuerstchen_diffnext"] = ["WuerstchenDiffNeXt"]
    _import_structure["modeling_wuerstchen_prior"] = ["WuerstchenPrior"]
    _import_structure["pipeline_wuerstchen"] = ["WuerstchenDecoderPipeline"]
    _import_structure["pipeline_wuerstchen_combined"] = ["WuerstchenCombinedPipeline"]
    _import_structure["pipeline_wuerstchen_prior"] = ["DEFAULT_STAGE_C_TIMESTEPS", "WuerstchenPriorPipeline"]


# Import eagerly for type checkers or when slow imports are requested; otherwise install a lazy module.
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .modeling_paella_vq_model import PaellaVQModel
        from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
        from .modeling_wuerstchen_prior import WuerstchenPrior
        from .pipeline_wuerstchen import WuerstchenDecoderPipeline
        from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
        from .pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPriorPipeline
else:
    import sys

    # Replace this module in `sys.modules` with the lazy proxy built from `_import_structure`.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)


================================================
FILE: src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py
================================================
# Copyright (c) 2022 Dominic Rampas MIT License
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...models.autoencoders.vae import DecoderOutput, VectorQuantizer
from ...models.modeling_utils import ModelMixin
from ...models.vq_model import VQEncoderOutput
from ...utils.accelerate_utils import apply_forward_hook


class MixingResidualBlock(nn.Module):
    """
    Residual block with mixing used by Paella's VQ-VAE.
    """

    def __init__(self, inp_channels, embed_dim):
        """
        Builds a depthwise 3x3 convolution branch and a channelwise MLP branch (`inp_channels` -> `embed_dim` ->
        `inp_channels`), each preceded by a non-affine LayerNorm over the channels.
        """
        super().__init__()
        # depthwise
        self.norm1 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(
            nn.ReplicationPad2d(1), nn.Conv2d(inp_channels, inp_channels, kernel_size=3, groups=inp_channels)
        )

        # channelwise
        self.norm2 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(inp_channels, embed_dim), nn.GELU(), nn.Linear(embed_dim, inp_channels)
        )

        # Six learnable scalars: (scale, shift, output gate) for each of the two branches. They start at zero, so
        # both residual branches contribute nothing at initialization.
        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

    def forward(self, x):
        """
        Applies the depthwise branch and then the channelwise branch, each as a gated residual update of `x`.
        """
        mods = self.gammas
        # The permutes move the channel dimension last for LayerNorm / Linear and back to position 1 afterwards.
        x_temp = self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[0]) + mods[1]
        x = x + self.depthwise(x_temp) * mods[2]
        x_temp = self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[3]) + mods[4]
        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
        return x


class PaellaVQModel(ModelMixin, ConfigMixin):
    r"""VQ-VAE model from Paella model.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the model (such as downloading or saving, etc.)

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        up_down_scale_factor (int, *optional*, defaults to 2): Up and Downscale factor of the input image.
        levels (int, *optional*, defaults to 2): Number of levels in the model.
        bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model.
        embed_dim (int, *optional*, defaults to 384): Number of hidden channels in the model.
        latent_channels (int, *optional*, defaults to 4): Number of latent channels in the VQ-VAE model.
        num_vq_embeddings (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE.
        scale_factor (float, *optional*, defaults to 0.3764): Scaling factor of the latent space.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_down_scale_factor: int = 2,
        levels: int = 2,
        bottleneck_blocks: int = 12,
        embed_dim: int = 384,
        latent_channels: int = 4,
        num_vq_embeddings: int = 8192,
        scale_factor: float = 0.3764,
    ):
        super().__init__()

        # Channel width per level, smallest first; the last level uses `embed_dim` channels.
        c_levels = [embed_dim // (2**i) for i in reversed(range(levels))]
        # Encoder blocks
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(up_down_scale_factor),
            nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1),
        )
        down_blocks = []
        for i in range(levels):
            if i > 0:
                # Strided convolution between levels: halves the spatial size and changes the channel width.
                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = MixingResidualBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(
            nn.Sequential(
                nn.Conv2d(c_levels[-1], latent_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(latent_channels),  # then normalize them to have mean 0 and std 1
            )
        )
        self.down_blocks = nn.Sequential(*down_blocks)

        # Vector Quantizer
        self.vquantizer = VectorQuantizer(num_vq_embeddings, vq_embed_dim=latent_channels, legacy=False, beta=0.25)

        # Decoder blocks
        up_blocks = [nn.Sequential(nn.Conv2d(latent_channels, c_levels[-1], kernel_size=1))]
        for i in range(levels):
            # The first decoder level holds `bottleneck_blocks` residual blocks, every later level a single one.
            for j in range(bottleneck_blocks if i == 0 else 1):
                block = MixingResidualBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(
                    nn.ConvTranspose2d(
                        c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1
                    )
                )
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            nn.Conv2d(c_levels[0], out_channels * up_down_scale_factor**2, kernel_size=1),
            nn.PixelShuffle(up_down_scale_factor),
        )

    @apply_forward_hook
    def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
        """
        Runs `x` through `in_block` and `down_blocks` and returns the resulting latents; no quantization is applied
        here. Returns a one-element tuple instead of a [`VQEncoderOutput`] when `return_dict` is `False`.
        """
        h = self.in_block(x)
        h = self.down_blocks(h)

        if not return_dict:
            return (h,)

        return VQEncoderOutput(latents=h)

    @apply_forward_hook
    def decode(
self, h: torch.Tensor, force_not_quantize: bool = True, return_dict: bool = True ) -> Union[DecoderOutput, torch.Tensor]: if not force_not_quantize: quant, _, _ = self.vquantizer(h) else: quant = h x = self.up_blocks(quant) dec = self.out_block(x) if not return_dict: return (dec,) return DecoderOutput(sample=dec) def forward(self, sample: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: r""" Args: sample (`torch.Tensor`): Input sample. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`DecoderOutput`] instead of a plain tuple. """ x = sample h = self.encode(x).latents dec = self.decode(h).sample if not return_dict: return (dec,) return DecoderOutput(sample=dec) ================================================ FILE: src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py ================================================ import torch import torch.nn as nn from ...models.attention_processor import Attention class WuerstchenLayerNorm(nn.LayerNorm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, x): x = x.permute(0, 2, 3, 1) x = super().forward(x) return x.permute(0, 3, 1, 2) class TimestepBlock(nn.Module): def __init__(self, c, c_timestep): super().__init__() self.mapper = nn.Linear(c_timestep, c * 2) def forward(self, x, t): a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1) return x * (1 + a) + b class ResBlock(nn.Module): def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): super().__init__() self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) self.channelwise = nn.Sequential( nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c) ) def forward(self, x, x_skip=None): x_res = x if x_skip is not None: x = torch.cat([x, x_skip], dim=1) x = self.norm(self.depthwise(x)).permute(0, 2, 
3, 1) x = self.channelwise(x).permute(0, 3, 1, 2) return x + x_res # from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 class GlobalResponseNorm(nn.Module): def __init__(self, dim): super().__init__() self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) def forward(self, x): agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True) stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6) return self.gamma * (x * stand_div_norm) + self.beta + x class AttnBlock(nn.Module): def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): super().__init__() self.self_attn = self_attn self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) def forward(self, x, kv): kv = self.kv_mapper(kv) norm_x = self.norm(x) if self.self_attn: batch_size, channel, _, _ = x.shape kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1) x = x + self.attention(norm_x, encoder_hidden_states=kv) return x ================================================ FILE: src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py ================================================ # Copyright (c) 2023 Dominic Rampas MIT License # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class WuerstchenDiffNeXt(ModelMixin, ConfigMixin):
    """
    Multi-level encoder/decoder network assembled from `ResBlockStageB` ("C"), `TimestepBlock` ("T") and `AttnBlock`
    ("A") blocks, conditioned on a ratio embedding, an optional spatial "effnet" map and optional text embeddings.

    Args:
        c_in (`int`): Channels of the input before pixel-unshuffling (the first conv takes `c_in * patch_size**2`).
        c_out (`int`): The output head produces `2 * c_out` channels after pixel-shuffling, split into two halves.
        c_r (`int`): Width of the sinusoidal ratio embedding built by `gen_r_embedding`.
        patch_size (`int`): Factor used by `nn.PixelUnshuffle` on the way in and `nn.PixelShuffle` on the way out.
        c_cond (`int`): Width of the conditioning features (mapped text embeddings and mapped effnet features).
        c_hidden (`List[int]`): Channel width of each level.
        nhead (`List[int]`): Attention heads of each level (only read for levels that contain an "A" block).
        blocks (`List[int]`): How many times the `level_config` pattern is repeated at each level.
        level_config (`List[str]`): Per-level string of block-type letters, e.g. `"CTA"`.
        inject_effnet (`List[bool]`): Whether the effnet features are fed to the residual blocks of each level.
        effnet_embd (`int`): Channels of the incoming effnet feature map.
        clip_embd (`int`): Width of the incoming text embeddings.
        kernel_size (`int`): Depthwise kernel size of the residual blocks.
        dropout (`float` or `List[float]`): Dropout rate; a scalar is broadcast to every level.
    """

    # NOTE: the list defaults are only read, never mutated, and are recorded as-is by `register_to_config`.
    @register_to_config
    def __init__(
        self,
        c_in=4,
        c_out=4,
        c_r=64,
        patch_size=2,
        c_cond=1024,
        c_hidden=[320, 640, 1280, 1280],
        nhead=[-1, 10, 20, 20],
        blocks=[4, 4, 14, 4],
        level_config=["CT", "CTA", "CTA", "CTA"],
        inject_effnet=[False, True, True, True],
        effnet_embd=16,
        clip_embd=1024,
        kernel_size=3,
        dropout=0.1,
    ):
        super().__init__()
        self.c_r = c_r
        self.c_cond = c_cond
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)

        # CONDITIONING
        self.clip_mapper = nn.Linear(clip_embd, c_cond)
        # One (optional) 1x1 conv per level: first the down levels in order, then the same flags mirrored for the
        # up path. `_up_decode` indexes the second half with `len(self.down_blocks) + i`.
        self.effnet_mappers = nn.ModuleList(
            [
                nn.Conv2d(effnet_embd, c_cond, kernel_size=1) if inject else None
                for inject in inject_effnet + list(reversed(inject_effnet))
            ]
        )
        self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)

        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1),
            WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6),
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0):
            # Map one letter of `level_config` to the module it stands for.
            if block_type == "C":
                return ResBlockStageB(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
            elif block_type == "A":
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=True, dropout=dropout)
            elif block_type == "T":
                return TimestepBlock(c_hidden, c_r)
            else:
                raise ValueError(f"Block type {block_type} not supported")

        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        for i in range(len(c_hidden)):
            down_block = nn.ModuleList()
            if i > 0:
                # strided conv halves the resolution and switches to this level's width
                down_block.append(
                    nn.Sequential(
                        WuerstchenLayerNorm(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
                        nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
                    )
                )
            for _ in range(blocks[i]):
                for block_type in level_config[i]:
                    c_skip = c_cond if inject_effnet[i] else 0
                    down_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i]))
            self.down_blocks.append(down_block)

        # -- up blocks
        self.up_blocks = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            up_block = nn.ModuleList()
            for j in range(blocks[i]):
                for k, block_type in enumerate(level_config[i]):
                    # only the very first block of every non-deepest level receives the encoder skip connection
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    c_skip += c_cond if inject_effnet[i] else 0
                    up_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i]))
            if i > 0:
                # transposed conv doubles the resolution and switches to the previous level's width
                up_block.append(
                    nn.Sequential(
                        WuerstchenLayerNorm(c_hidden[i], elementwise_affine=False, eps=1e-6),
                        nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
                    )
                )
            self.up_blocks.append(up_block)

        # OUTPUT
        self.clf = nn.Sequential(
            WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6),
            nn.Conv2d(c_hidden[0], 2 * c_out * (patch_size**2), kernel_size=1),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initializer handed to `self.apply` in `__init__`; `m` is the module currently being visited.

        Conv/linear layers get Xavier-uniform weights and zero bias; afterwards the model-level special cases
        (conditioning mappers, input embedding, output head, per-block scaling) are (re-)applied.
        """
        # General init
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

        # NOTE(review): everything below does not depend on `m`, yet it runs once per module visited by
        # `self.apply`. In particular the in-place `*=` on the residual blocks' last linear layer is applied
        # repeatedly, so the effective scale depends on the visiting order. Kept as-is to preserve the existing
        # initialization; confirm whether a single post-`apply` pass was intended.
        for mapper in self.effnet_mappers:
            if mapper is not None:
                nn.init.normal_(mapper.weight, std=0.02)  # conditionings
        nn.init.normal_(self.clip_mapper.weight, std=0.02)  # conditionings
        nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # inputs
        nn.init.constant_(self.clf[1].weight, 0)  # outputs

        # blocks
        for level_block in self.down_blocks + self.up_blocks:
            for block in level_block:
                if isinstance(block, ResBlockStageB):
                    block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks))
                elif isinstance(block, TimestepBlock):
                    nn.init.constant_(block.mapper.weight, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        """
        Build a sinusoidal embedding of width `self.c_r` for the 1-D tensor `r`.

        Args:
            r (`torch.Tensor`): 1-D tensor of ratios, one per sample.
            max_positions (`int`): Scale applied to `r` and base of the geometric frequency progression.

        Returns:
            `torch.Tensor` of shape `(len(r), self.c_r)` with the dtype of `r`.
        """
        r = r * max_positions
        half_dim = self.c_r // 2
        # `half_dim - 1` is zero for c_r in {2, 3}; clamp the divisor so those widths no longer raise
        # ZeroDivisionError (with a single frequency the exponent is multiplied by 0 anyway).
        emb = math.log(max_positions) / max(half_dim - 1, 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode="constant")
        return emb.to(dtype=r.dtype)

    def gen_c_embeddings(self, clip):
        """Project the text embeddings to `c_cond` features and layer-normalize them (no affine parameters)."""
        clip = self.clip_mapper(clip)
        clip = self.seq_norm(clip)
        return clip

    def _down_encode(self, x, r_embed, effnet, clip=None):
        """
        Run the encoder path.

        Returns:
            `list` of per-level outputs, deepest level first (each level output is inserted at index 0).
        """
        level_outputs = []
        for i, down_block in enumerate(self.down_blocks):
            effnet_c = None
            for block in down_block:
                if isinstance(block, ResBlockStageB):
                    if effnet_c is None and self.effnet_mappers[i] is not None:
                        # resize the effnet map to the current resolution once per level; the interpolation is
                        # done in float32 and cast back to the incoming dtype
                        dtype = effnet.dtype
                        effnet_c = self.effnet_mappers[i](
                            nn.functional.interpolate(
                                effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True
                            ).to(dtype)
                        )
                    skip = effnet_c if self.effnet_mappers[i] is not None else None
                    x = block(x, skip)
                elif isinstance(block, AttnBlock):
                    x = block(x, clip)
                elif isinstance(block, TimestepBlock):
                    x = block(x, r_embed)
                else:
                    # down-sampling `nn.Sequential`
                    x = block(x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, effnet, clip=None):
        """Run the decoder path over `level_outputs` (deepest first) and return the final feature map."""
        x = level_outputs[0]
        for i, up_block in enumerate(self.up_blocks):
            effnet_c = None
            for j, block in enumerate(up_block):
                if isinstance(block, ResBlockStageB):
                    if effnet_c is None and self.effnet_mappers[len(self.down_blocks) + i] is not None:
                        dtype = effnet.dtype
                        effnet_c = self.effnet_mappers[len(self.down_blocks) + i](
                            nn.functional.interpolate(
                                effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True
                            ).to(dtype)
                        )
                    # the encoder skip is only consumed by the first block of every level after the deepest one
                    skip = level_outputs[i] if j == 0 and i > 0 else None
                    if effnet_c is not None:
                        if skip is not None:
                            skip = torch.cat([skip, effnet_c], dim=1)
                        else:
                            skip = effnet_c
                    x = block(x, skip)
                elif isinstance(block, AttnBlock):
                    x = block(x, clip)
                elif isinstance(block, TimestepBlock):
                    x = block(x, r_embed)
                else:
                    # up-sampling `nn.Sequential`
                    x = block(x)
        return x

    def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=True):
        """
        Args:
            x (`torch.Tensor`): Input latents.
            r (`torch.Tensor`): 1-D tensor of ratios, embedded with `gen_r_embedding`.
            effnet (`torch.Tensor`): Spatial conditioning map, resized and injected per level.
            clip (`torch.Tensor`, *optional*): Text embeddings; mapped with `gen_c_embeddings` when given.
            x_cat (`torch.Tensor`, *optional*): Extra channels concatenated to `x` along dim 1.
            eps (`float`): Keeps the second output head inside `[eps, 1 - eps]`.
            return_noise (`bool`): Return `(x_in - a) / b` when `True`, otherwise the raw pair `(a, b)`.
        """
        if x_cat is not None:
            x = torch.cat([x, x_cat], dim=1)
        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r)
        if clip is not None:
            clip = self.gen_c_embeddings(clip)

        # Model Blocks
        x_in = x
        x = self.embedding(x)
        level_outputs = self._down_encode(x, r_embed, effnet, clip)
        x = self._up_decode(level_outputs, r_embed, effnet, clip)
        a, b = self.clf(x).chunk(2, dim=1)
        # squash b into [eps, 1 - eps] so the division below is well-defined
        b = b.sigmoid() * (1 - eps * 2) + eps
        if return_noise:
            return (x_in - a) / b
        else:
            return a, b
class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
    """
    Stack of `depth` repetitions of (`ResBlock`, `TimestepBlock`, `AttnBlock`) operating on `c`-channel feature maps.

    Args:
        c_in (`int`): Channels of the input; the output head produces `2 * c_in` channels.
        c (`int`): Hidden width used by every block.
        c_cond (`int`): Width of the conditioning embeddings fed to `cond_mapper`.
        c_r (`int`): Width of the sinusoidal ratio embedding built by `gen_r_embedding`.
        depth (`int`): Number of (`ResBlock`, `TimestepBlock`, `AttnBlock`) triples.
        nhead (`int`): Attention heads of every `AttnBlock`.
        dropout (`float`): Dropout rate passed to the residual and attention blocks.
    """

    unet_name = "prior"
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
        super().__init__()

        self.c_r = c_r
        self.projection = nn.Conv2d(c_in, c, kernel_size=1)
        self.cond_mapper = nn.Sequential(
            nn.Linear(c_cond, c),
            nn.LeakyReLU(0.2),
            nn.Linear(c, c),
        )

        self.blocks = nn.ModuleList()
        for _ in range(depth):
            self.blocks.append(ResBlock(c, dropout=dropout))
            self.blocks.append(TimestepBlock(c, c_r))
            self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
        self.out = nn.Sequential(
            WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
            nn.Conv2d(c, c_in * 2, kernel_size=1),
        )

        self.gradient_checkpointing = False
        self.set_default_attn_processor()

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing for the whole model; `module` is accepted but not inspected."""
        self.gradient_checkpointing = value

    def gen_r_embedding(self, r, max_positions=10000):
        """
        Build a sinusoidal embedding of width `self.c_r` for the 1-D tensor `r`.

        Args:
            r (`torch.Tensor`): 1-D tensor of ratios, one per sample.
            max_positions (`int`): Scale applied to `r` and base of the geometric frequency progression.

        Returns:
            `torch.Tensor` of shape `(len(r), self.c_r)` with the dtype of `r`.
        """
        r = r * max_positions
        half_dim = self.c_r // 2
        # `half_dim - 1` is zero for c_r in {2, 3}; clamp the divisor so those widths no longer raise
        # ZeroDivisionError (with a single frequency the exponent is multiplied by 0 anyway).
        emb = math.log(max_positions) / max(half_dim - 1, 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode="constant")
        return emb.to(dtype=r.dtype)

    def forward(self, x, r, c):
        """
        Args:
            x (`torch.Tensor`): Input feature map with `c_in` channels.
            r (`torch.Tensor`): 1-D tensor of ratios, embedded with `gen_r_embedding`.
            c (`torch.Tensor`): Conditioning embeddings, mapped by `cond_mapper` and consumed by the attention blocks.

        Returns:
            `torch.Tensor`: `(x - a) / (|1 - b| + 1e-5)` where `a` and `b` are the two halves of the output head.
        """
        x_in = x
        x = self.projection(x)
        c_embed = self.cond_mapper(c)
        r_embed = self.gen_r_embedding(r)

        use_checkpointing = self.training and self.gradient_checkpointing
        # `use_reentrant` is only passed on torch >= 1.11.0, matching the previous version-specific branches
        ckpt_kwargs = {"use_reentrant": False} if use_checkpointing and is_torch_version(">=", "1.11.0") else {}

        for block in self.blocks:
            # each block type consumes a different conditioning signal
            if isinstance(block, AttnBlock):
                block_args = (x, c_embed)
            elif isinstance(block, TimestepBlock):
                block_args = (x, r_embed)
            else:
                block_args = (x,)

            if use_checkpointing:
                x = torch.utils.checkpoint.checkpoint(block, *block_args, **ckpt_kwargs)
            else:
                x = block(*block_args)

        a, b = self.out(x).chunk(2, dim=1)
        return (x_in - a) / ((1 - b).abs() + 1e-5)
# Usage example injected into `WuerstchenDecoderPipeline.__call__`'s docstring via `replace_example_docstring`.
# The snippet must be runnable as written: the loader is `from_pretrained`, and the prior pipeline created above
# is named `prior_pipe`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline

        >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
        ...     "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
        ... ).to("cuda")
        >>> gen_pipe = WuerstchenDecoderPipeline.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to(
        ...     "cuda"
        ... )

        >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
        >>> prior_output = prior_pipe(prompt)
        >>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt)
        ```
"""
    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    # NOTE: because of the marker above, the code below is meant to mirror the referenced method; only comments
    # were added here.
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        # No latents supplied: draw fresh noise of the requested shape/dtype on `device`.
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # Caller-supplied latents must already have the expected shape; they are only moved to `device`
            # (their dtype is left untouched).
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # Scale by the scheduler's initial noise sigma in both cases.
        latents = latents * scheduler.init_noise_sigma
        return latents
"The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] attention_mask = attention_mask[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device)) text_encoder_hidden_states = text_encoder_output.last_hidden_state text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_text_encoder_hidden_states = None if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds_text_encoder_output = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device) ) uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) # done duplicates # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes return text_encoder_hidden_states, uncond_text_encoder_hidden_states @property def guidance_scale(self): return self._guidance_scale @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @property def num_timesteps(self): return self._num_timesteps @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, image_embeddings: Union[torch.Tensor, List[torch.Tensor]], prompt: Union[str, List[str]] = None, num_inference_steps: int = 12, timesteps: Optional[List[float]] = None, guidance_scale: float = 0.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): """ 
Function invoked when calling the pipeline for generation. Args: image_embedding (`torch.Tensor` or `List[torch.Tensor]`): Image Embeddings either extracted from an image or generated by a Prior Model. prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. num_inference_steps (`int`, *optional*, defaults to 12): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 0.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `decoder_guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. 
output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image embeddings. 
""" callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) # 0. Define commonly used variables device = self._execution_device dtype = self.decoder.dtype self._guidance_scale = guidance_scale # 1. Check inputs. Raise error if not correct if not isinstance(prompt, list): if isinstance(prompt, str): prompt = [prompt] else: raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") if self.do_classifier_free_guidance: if negative_prompt is not None and not isinstance(negative_prompt, list): if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] else: raise TypeError( f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}." ) if isinstance(image_embeddings, list): image_embeddings = torch.cat(image_embeddings, dim=0) if isinstance(image_embeddings, np.ndarray): image_embeddings = torch.Tensor(image_embeddings, device=device).to(dtype=dtype) if not isinstance(image_embeddings, torch.Tensor): raise TypeError( f"'image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(image_embeddings)}." 
) if not isinstance(num_inference_steps, int): raise TypeError( f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ In Case you want to provide explicit timesteps, please use the 'timesteps' argument." ) # 2. Encode caption prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, image_embeddings.size(0) * num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt, ) text_encoder_hidden_states = ( torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds ) effnet = ( torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) if self.do_classifier_free_guidance else image_embeddings ) # 3. Determine latent shape of latents latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale) latent_width = int(image_embeddings.size(3) * self.config.latent_dim_scale) latent_features_shape = (image_embeddings.size(0) * num_images_per_prompt, 4, latent_height, latent_width) # 4. Prepare and set timesteps if timesteps is not None: self.scheduler.set_timesteps(timesteps=timesteps, device=device) timesteps = self.scheduler.timesteps num_inference_steps = len(timesteps) else: self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latents latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler) # 6. Run denoising loop self._num_timesteps = len(timesteps[:-1]) for i, t in enumerate(self.progress_bar(timesteps[:-1])): ratio = t.expand(latents.size(0)).to(dtype) # 7. Denoise latents predicted_latents = self.decoder( torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, r=torch.cat([ratio] * 2) if self.do_classifier_free_guidance else ratio, effnet=effnet, clip=text_encoder_hidden_states, ) # 8. 
Check for classifier free guidance and apply it if self.do_classifier_free_guidance: predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2) predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale) # 9. Renoise latents to next timestep latents = self.scheduler.step( model_output=predicted_latents, timestep=ratio, sample=latents, generator=generator, ).prev_sample if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) image_embeddings = callback_outputs.pop("image_embeddings", image_embeddings) text_encoder_hidden_states = callback_outputs.pop( "text_encoder_hidden_states", text_encoder_hidden_states ) if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError( f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}" ) if not output_type == "latent": # 10. Scale and decode the image latents with vq-vae latents = self.vqgan.config.scale_factor * latents images = self.vqgan.decode(latents).sample.clamp(0, 1) if output_type == "np": images = images.permute(0, 2, 3, 1).cpu().float().numpy() elif output_type == "pil": images = images.permute(0, 2, 3, 1).cpu().float().numpy() images = self.numpy_to_pil(images) else: images = latents # Offload all models self.maybe_free_model_hooks() if not return_dict: return images return ImagePipelineOutput(images) ================================================ FILE: src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. 
# Usage example injected into `WuerstchenCombinedPipeline.__call__`'s docstring via `replace_example_docstring`.
# The snippet must be runnable as written: the package is `diffusers`, and `torch` has to be imported because the
# example references `torch.float16`.
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import WuerstchenCombinedPipeline

        >>> pipe = WuerstchenCombinedPipeline.from_pretrained("warp-ai/Wuerstchen", torch_dtype=torch.float16).to(
        ...     "cuda"
        ... )
        >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
        >>> images = pipe(prompt=prompt)
        ```
"""
    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        decoder: WuerstchenDiffNeXt,
        scheduler: DDPMWuerstchenScheduler,
        vqgan: PaellaVQModel,
        prior_tokenizer: CLIPTokenizer,
        prior_text_encoder: CLIPTextModel,
        prior_prior: WuerstchenPrior,
        prior_scheduler: DDPMWuerstchenScheduler,
    ):
        # Register all nine components on this pipeline and build the two sub-pipelines that share them.
        super().__init__()

        # Decoder-side components keep their plain names; prior-side components carry the `prior_` prefix.
        self.register_modules(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            decoder=decoder,
            scheduler=scheduler,
            vqgan=vqgan,
            prior_prior=prior_prior,
            prior_text_encoder=prior_text_encoder,
            prior_tokenizer=prior_tokenizer,
            prior_scheduler=prior_scheduler,
        )
        # The prior sub-pipeline receives the `prior_*` components under its own (unprefixed) argument names.
        self.prior_pipe = WuerstchenPriorPipeline(
            prior=prior_prior,
            text_encoder=prior_text_encoder,
            tokenizer=prior_tokenizer,
            scheduler=prior_scheduler,
        )
        # The decoder sub-pipeline is built from the same module objects registered above (no copies are made).
        self.decoder_pipe = WuerstchenDecoderPipeline(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            decoder=decoder,
            scheduler=scheduler,
            vqgan=vqgan,
        )
Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): r""" Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis. Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower. 
""" self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) def progress_bar(self, iterable=None, total=None): self.prior_pipe.progress_bar(iterable=iterable, total=total) self.decoder_pipe.progress_bar(iterable=iterable, total=total) def set_progress_bar_config(self, **kwargs): self.prior_pipe.set_progress_bar_config(**kwargs) self.decoder_pipe.set_progress_bar_config(**kwargs) @torch.no_grad() @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING) def __call__( self, prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 512, prior_num_inference_steps: int = 60, prior_timesteps: Optional[List[float]] = None, prior_guidance_scale: float = 4.0, num_inference_steps: int = 12, decoder_timesteps: Optional[List[float]] = None, decoder_guidance_scale: float = 0.0, negative_prompt: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation for the prior and decoder. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. prior_guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `prior_guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 60): The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. For more specific timestep spacing, you can pass customized `prior_timesteps` num_inference_steps (`int`, *optional*, defaults to 12): The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. For more specific timestep spacing, you can pass customized `timesteps` prior_timesteps (`List[float]`, *optional*): Custom timesteps to use for the denoising process for the prior. 
If not defined, equal spaced `prior_num_inference_steps` timesteps are used. Must be in descending order. decoder_timesteps (`List[float]`, *optional*): Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. decoder_guidance_scale (`float`, *optional*, defaults to 0.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. prior_callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. 
prior_callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. 
""" prior_kwargs = {} if kwargs.get("prior_callback", None) is not None: prior_kwargs["callback"] = kwargs.pop("prior_callback") deprecate( "prior_callback", "1.0.0", "Passing `prior_callback` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", ) if kwargs.get("prior_callback_steps", None) is not None: deprecate( "prior_callback_steps", "1.0.0", "Passing `prior_callback_steps` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", ) prior_kwargs["callback_steps"] = kwargs.pop("prior_callback_steps") prior_outputs = self.prior_pipe( prompt=prompt if prompt_embeds is None else None, height=height, width=width, num_inference_steps=prior_num_inference_steps, timesteps=prior_timesteps, guidance_scale=prior_guidance_scale, negative_prompt=negative_prompt if negative_prompt_embeds is None else None, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, num_images_per_prompt=num_images_per_prompt, generator=generator, latents=latents, output_type="pt", return_dict=False, callback_on_step_end=prior_callback_on_step_end, callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, **prior_kwargs, ) image_embeddings = prior_outputs[0] outputs = self.decoder_pipe( image_embeddings=image_embeddings, prompt=prompt if prompt is not None else "", num_inference_steps=num_inference_steps, timesteps=decoder_timesteps, guidance_scale=decoder_guidance_scale, negative_prompt=negative_prompt, generator=generator, output_type=output_type, return_dict=return_dict, callback_on_step_end=callback_on_step_end, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, **kwargs, ) return outputs ================================================ FILE: src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. 
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import WuerstchenPriorPipeline

        >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
        ...     "warp-ai/wuerstchen-prior", torch_dtype=torch.float16
        ... ).to("cuda")

        >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
        >>> prior_output = prior_pipe(prompt)
        ```
"""


@dataclass
class WuerstchenPriorPipelineOutput(BaseOutput):
    """
    Output class for WuerstchenPriorPipeline.

    Args:
        image_embeddings (`torch.Tensor` or `np.ndarray`)
            Prior image embeddings for text prompt

    """

    image_embeddings: Union[torch.Tensor, np.ndarray]


class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
    """
    Pipeline for generating image prior for Wuerstchen.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    The pipeline also inherits the following loading methods:
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights

    Args:
        prior ([`Prior`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        scheduler ([`DDPMWuerstchenScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
        latent_mean ('float', *optional*, defaults to 42.0):
            Mean value for latent diffusers.
        latent_std ('float', *optional*, defaults to 1.0):
            Standard value for latent diffusers.
        resolution_multiple ('float', *optional*, defaults to 42.67):
            Default resolution for multiple images generated.
    """

    unet_name = "prior"
    text_encoder_name = "text_encoder"
    model_cpu_offload_seq = "text_encoder->prior"
    _callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"]
    _lora_loadable_modules = ["prior", "text_encoder"]

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        prior: WuerstchenPrior,
        scheduler: DDPMWuerstchenScheduler,
        latent_mean: float = 42.0,
        latent_std: float = 1.0,
        resolution_multiple: float = 42.67,
    ) -> None:
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            prior=prior,
            scheduler=scheduler,
        )
        self.register_to_config(
            latent_mean=latent_mean, latent_std=latent_std, resolution_multiple=resolution_multiple
        )

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def encode_prompt(
        self,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        prompt=None,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
    ):
        """
        Encode the positive (and, under classifier-free guidance, the negative) prompt with the text encoder.

        Args:
            device: Device the token ids and resulting embeddings are moved to.
            num_images_per_prompt: Number of times every embedding is repeated along the batch dimension.
            do_classifier_free_guidance: Whether unconditional embeddings have to be produced/expanded.
            prompt (`str` or `List[str]`, *optional*): Text to encode; ignored when `prompt_embeds` is given.
            negative_prompt (`str` or `List[str]`, *optional*): Text for the unconditional branch; an empty
                string per batch item is used when it is `None`.
            prompt_embeds (`torch.Tensor`, *optional*): Pre-computed text embeddings used instead of `prompt`.
            negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-computed unconditional embeddings.

        Returns:
            `tuple`: `(prompt_embeds, negative_prompt_embeds)`; the second element is whatever was passed in
            (possibly `None`) when classifier-free guidance is disabled.

        Raises:
            TypeError: If `prompt` is given and `negative_prompt` is of a different type.
            ValueError: If a list `negative_prompt` does not match the batch size.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            attention_mask = text_inputs.attention_mask

            # Tokenize again without truncation purely to detect (and report) text that was cut off.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )
                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
                attention_mask = attention_mask[:, : self.tokenizer.model_max_length]

            text_encoder_output = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask.to(device)
            )
            prompt_embeds = text_encoder_output.last_hidden_state

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)

        if negative_prompt_embeds is None and do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            # Only compare types when a textual prompt exists: with `prompt_embeds` the prompt is `None`, and
            # comparing against `NoneType` would reject every negative prompt.
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            negative_prompt_embeds_text_encoder_output = self.text_encoder(
                uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device)
            )

            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.last_hidden_state

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
            # done duplicates

        return prompt_embeds, negative_prompt_embeds

    def check_inputs(
        self,
        prompt,
        negative_prompt,
        num_inference_steps,
        do_classifier_free_guidance,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """
        Validate the combination of prompt / embedding arguments and the type of `num_inference_steps`.

        Raises:
            ValueError: If both or neither of `prompt` and `prompt_embeds` are given, if `prompt` is neither `str`
                nor `list`, if both `negative_prompt` and `negative_prompt_embeds` are given, or if the two
                embedding tensors differ in shape.
            TypeError: If `num_inference_steps` is not an `int`.
        """
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if not isinstance(num_inference_steps, int):
            # Adjacent literals instead of a backslash continuation, which leaked source indentation into the text.
            raise TypeError(
                f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}."
                " In Case you want to provide explicit timesteps, please use the 'timesteps' argument."
            )

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 60,
        timesteps: Optional[List[float]] = None,
        guidance_scale: float = 8.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pt",
        return_dict: bool = True,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 1024):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 1024):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 60):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[float]`, *optional*):
                Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
                timesteps are used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 8.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pt"`):
                The output format of the image embeddings. `"np"` converts them to a float `np.array`; any other
                value returns the `torch.Tensor` unchanged.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.WuerstchenPriorPipelineOutput`] instead of a plain tuple.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.WuerstchenPriorPipelineOutput`] or `tuple` [`~pipelines.WuerstchenPriorPipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated image embeddings.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback is not None and callback_steps is None:
            # The legacy callback is gated by `i % callback_steps`; without a value that would raise a
            # `TypeError`, so invoke the callback on every step.
            callback_steps = 1

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # 0. Define commonly used variables
        device = self._execution_device
        self._guidance_scale = guidance_scale

        # 1. Check inputs. Raise error if not correct
        if prompt is not None and not isinstance(prompt, list):
            if isinstance(prompt, str):
                prompt = [prompt]
            else:
                raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.")

        if self.do_classifier_free_guidance:
            if negative_prompt is not None and not isinstance(negative_prompt, list):
                if isinstance(negative_prompt, str):
                    negative_prompt = [negative_prompt]
                else:
                    raise TypeError(
                        f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}."
                    )

        self.check_inputs(
            prompt,
            negative_prompt,
            num_inference_steps,
            self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # The batch size is derived only after validation, so a call without `prompt` and `prompt_embeds` fails
        # with the descriptive `ValueError` from `check_inputs` rather than an `AttributeError` on `None`.
        # `prompt` is either a list or `None` at this point.
        batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

        # 2. Encode caption
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        text_encoder_hidden_states = (
            torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
        )

        # 3. Determine latent shape of image embeddings
        dtype = text_encoder_hidden_states.dtype
        latent_height = ceil(height / self.config.resolution_multiple)
        latent_width = ceil(width / self.config.resolution_multiple)
        num_channels = self.prior.config.c_in
        effnet_features_shape = (num_images_per_prompt * batch_size, num_channels, latent_height, latent_width)

        # 4. Prepare and set timesteps
        if timesteps is not None:
            self.scheduler.set_timesteps(timesteps=timesteps, device=device)
            timesteps = self.scheduler.timesteps
            num_inference_steps = len(timesteps)
        else:
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            timesteps = self.scheduler.timesteps

        # 5. Prepare latents
        latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler)

        # 6. Run denoising loop (the last scheduler timestep is not iterated)
        self._num_timesteps = len(timesteps[:-1])
        for i, t in enumerate(self.progress_bar(timesteps[:-1])):
            ratio = t.expand(latents.size(0)).to(dtype)

            # 7. Denoise image embeddings
            predicted_image_embedding = self.prior(
                torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
                r=torch.cat([ratio] * 2) if self.do_classifier_free_guidance else ratio,
                c=text_encoder_hidden_states,
            )

            # 8. Check for classifier free guidance and apply it
            if self.do_classifier_free_guidance:
                predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2)
                predicted_image_embedding = torch.lerp(
                    predicted_image_embedding_uncond, predicted_image_embedding_text, self.guidance_scale
                )

            # 9. Renoise latents to next timestep
            latents = self.scheduler.step(
                model_output=predicted_image_embedding,
                timestep=ratio,
                sample=latents,
                generator=generator,
            ).prev_sample

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # The callback may hand back replacements for any of the tensors it was given.
                latents = callback_outputs.pop("latents", latents)
                text_encoder_hidden_states = callback_outputs.pop(
                    "text_encoder_hidden_states", text_encoder_hidden_states
                )
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

        # 10. Denormalize the latents
        # NOTE(review): the formula multiplies by `latent_mean` and subtracts `latent_std`, which does not match
        # the usual meaning of those names; kept unchanged - confirm against the trained model before touching.
        latents = latents * self.config.latent_mean - self.config.latent_std

        # Offload all models
        self.maybe_free_model_hooks()

        if output_type == "np":
            latents = latents.cpu().float().numpy()

        if not return_dict:
            return (latents,)

        return WuerstchenPriorPipelineOutput(latents)
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import TYPE_CHECKING from ..utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_flax_available, is_scipy_available, is_torch_available, is_torchsde_available, ) _dummy_modules = {} _import_structure = {} try: if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import dummy_pt_objects # noqa F403 _dummy_modules.update(get_objects_from_module(dummy_pt_objects)) else: _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"] _import_structure["scheduling_amused"] = ["AmusedScheduler"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] _import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"] _import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"] _import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"] _import_structure["scheduling_ddpm"] = ["DDPMScheduler"] _import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"] _import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"] _import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"] _import_structure["scheduling_dpm_cogvideox"] = 
["CogVideoXDPMScheduler"] _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"] _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"] _import_structure["scheduling_dpmsolver_singlestep"] = ["DPMSolverSinglestepScheduler"] _import_structure["scheduling_edm_dpmsolver_multistep"] = ["EDMDPMSolverMultistepScheduler"] _import_structure["scheduling_edm_euler"] = ["EDMEulerScheduler"] _import_structure["scheduling_euler_ancestral_discrete"] = ["EulerAncestralDiscreteScheduler"] _import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"] _import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"] _import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"] _import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"] _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"] _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"] _import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"] _import_structure["scheduling_lcm"] = ["LCMScheduler"] _import_structure["scheduling_pndm"] = ["PNDMScheduler"] _import_structure["scheduling_repaint"] = ["RePaintScheduler"] _import_structure["scheduling_sasolver"] = ["SASolverScheduler"] _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"] _import_structure["scheduling_tcd"] = ["TCDScheduler"] _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"] _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"] _import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"] _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"] try: if not is_flax_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import dummy_flax_objects # noqa F403 
_dummy_modules.update(get_objects_from_module(dummy_flax_objects)) else: _import_structure["scheduling_ddim_flax"] = ["FlaxDDIMScheduler"] _import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"] _import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"] _import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"] _import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"] _import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"] _import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"] _import_structure["scheduling_sde_ve_flax"] = ["FlaxScoreSdeVeScheduler"] _import_structure["scheduling_utils_flax"] = [ "FlaxKarrasDiffusionSchedulers", "FlaxSchedulerMixin", "FlaxSchedulerOutput", "broadcast_to_shape_from_left", ] try: if not (is_torch_available() and is_scipy_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import dummy_torch_and_scipy_objects # noqa F403 _dummy_modules.update(get_objects_from_module(dummy_torch_and_scipy_objects)) else: _import_structure["scheduling_lms_discrete"] = ["LMSDiscreteScheduler"] try: if not (is_torch_available() and is_torchsde_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import dummy_torch_and_torchsde_objects # noqa F403 _dummy_modules.update(get_objects_from_module(dummy_torch_and_torchsde_objects)) else: _import_structure["scheduling_cosine_dpmsolver_multistep"] = ["CosineDPMSolverMultistepScheduler"] _import_structure["scheduling_dpmsolver_sde"] = ["DPMSolverSDEScheduler"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from ..utils import ( OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available, is_torchsde_available, ) try: if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # 
noqa F403 else: from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler from .scheduling_amused import AmusedScheduler from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddim_parallel import DDIMParallelScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_ddpm_parallel import DDPMParallelScheduler from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler from .scheduling_deis_multistep import DEISMultistepScheduler from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler from .scheduling_edm_dpmsolver_multistep import EDMDPMSolverMultistepScheduler from .scheduling_edm_euler import EDMEulerScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler from .scheduling_heun_discrete import HeunDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler from .scheduling_lcm import LCMScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_repaint import RePaintScheduler from .scheduling_sasolver import SASolverScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_tcd import TCDScheduler from 
.scheduling_unclip import UnCLIPScheduler from .scheduling_unipc_multistep import UniPCMultistepScheduler from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler try: if not is_flax_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_flax_objects import * # noqa F403 else: from .scheduling_ddim_flax import FlaxDDIMScheduler from .scheduling_ddpm_flax import FlaxDDPMScheduler from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler from .scheduling_euler_discrete_flax import FlaxEulerDiscreteScheduler from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler from .scheduling_utils_flax import ( FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left, ) try: if not (is_torch_available() and is_scipy_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 else: from .scheduling_lms_discrete import LMSDiscreteScheduler try: if not (is_torch_available() and is_torchsde_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 else: from .scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler else: import sys sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) for name, value in _dummy_modules.items(): setattr(sys.modules[__name__], name, value) ================================================ FILE: 
# Lazy-import entry point for the deprecated schedulers.
#
# The real scheduler modules are only imported on first attribute access (via `_LazyModule`),
# unless a type checker is running or `DIFFUSERS_SLOW_IMPORT` is set, in which case they are
# imported eagerly below.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,  # noqa: F401  (kept so existing imports of this name keep working)
)


_dummy_objects = {}
_import_structure = {}

# Both deprecated schedulers are gated on torch only. This matches the eager branch below and the
# parent `schedulers` package, which registers "deprecated" whenever torch is available. Also
# requiring `transformers` here made the lazy path fall back to the torch dummy objects on a
# torch-only install, even though the real schedulers were importable.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_pt_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
    _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"]
    _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_pt_objects import *  # noqa F403
    else:
        from .scheduling_karras_ve import KarrasVeScheduler
        from .scheduling_sde_vp import ScoreSdeVpScheduler

else:
    import sys

    # Replace this module with a lazy proxy that imports submodules on first attribute access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Expose the placeholder objects collected above on the lazy module.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
@dataclass
class KarrasVeOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        derivative (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Derivative of predicted original image sample (x_0).
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.Tensor
    derivative: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
    """
    A stochastic scheduler tailored to variance-expanding models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    <Tip>

    For more details on the parameters, see [Appendix E](https://arxiv.org/abs/2206.00364). The grid search values used
    to find the optimal `{s_noise, s_churn, s_min, s_max}` for a specific model are described in Table 5 of the paper.

    </Tip>

    Args:
        sigma_min (`float`, defaults to 0.02):
            The minimum noise magnitude.
        sigma_max (`float`, defaults to 100):
            The maximum noise magnitude.
        s_noise (`float`, defaults to 1.007):
            The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
            1.011].
        s_churn (`float`, defaults to 80):
            The parameter controlling the overall amount of stochasticity. A reasonable range is [0, 100].
        s_min (`float`, defaults to 0.05):
            The start value of the sigma range to add noise (enable stochasticity). A reasonable range is [0, 10].
        s_max (`float`, defaults to 50):
            The end value of the sigma range to add noise. A reasonable range is [0.2, 80].
    """

    order = 2

    @register_to_config
    def __init__(
        self,
        sigma_min: float = 0.02,
        sigma_max: float = 100,
        s_noise: float = 1.007,
        s_churn: float = 80,
        s_min: float = 0.05,
        s_max: float = 50,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        # setable values, populated by `set_timesteps`
        self.num_inference_steps: Optional[int] = None
        self.timesteps: Optional[torch.Tensor] = None
        self.schedule: Optional[torch.Tensor] = None  # sigma(t_i)

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. This scheduler applies no scaling.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain. Unused.

        Returns:
            `torch.Tensor`:
                The input sample, returned unchanged.
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        `self.timesteps` becomes `[num_inference_steps - 1, ..., 0]` and `self.schedule` holds one value per entry of
        `self.timesteps`, interpolated geometrically between `sigma_max**2` (index `i = 0`) and `sigma_min**2`
        (index `i = num_inference_steps - 1`).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps
        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps).to(device)

        # With a single step the interpolation fraction i / (N - 1) would be 0 / 0 and turn the whole
        # schedule into NaN; clamp the denominator so the lone step sits at the `sigma_max` end.
        denominator = max(num_inference_steps - 1, 1)
        sigma_min_sq = self.config.sigma_min**2
        sigma_max_sq = self.config.sigma_max**2
        schedule = sigma_max_sq * (sigma_min_sq / sigma_max_sq) ** (self.timesteps / denominator)
        self.schedule = schedule.to(torch.float32)

    def add_noise_to_input(
        self, sample: torch.Tensor, sigma: float, generator: Optional[torch.Generator] = None
    ) -> Tuple[torch.Tensor, float]:
        """
        Explicit Langevin-like "churn" step of adding noise to the sample according to a `gamma_i ≥ 0` to reach a
        higher noise level `sigma_hat = sigma_i + gamma_i*sigma_i`.

        `gamma_i` is `min(s_churn / num_inference_steps, sqrt(2) - 1)` when `s_min <= sigma <= s_max` and `0`
        otherwise, so no noise is added outside that range.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            sigma (`float`):
                The current noise level `sigma_i`.
            generator (`torch.Generator`, *optional*):
                A random number generator.

        Returns:
            `Tuple[torch.Tensor, float]`:
                The noised sample `sample_hat` and the increased noise level `sigma_hat`.
        """
        if self.config.s_min <= sigma <= self.config.s_max:
            gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
        else:
            gamma = 0

        # sample eps ~ N(0, S_noise^2 * I)
        eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device)
        sigma_hat = sigma + gamma * sigma
        sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)

        return sample_hat, sigma_hat

    def step(
        self,
        model_output: torch.Tensor,
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[KarrasVeOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            sigma_hat (`float`):
                The noise level of `sample_hat` (as returned by `add_noise_to_input`).
            sigma_prev (`float`):
                The noise level to step to.
            sample_hat (`torch.Tensor`):
                The sample at noise level `sigma_hat` (as returned by `add_noise_to_input`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] is returned, otherwise a
                tuple `(sample_prev, derivative)` is returned where the first element is the sample tensor.

        """

        pred_original_sample = sample_hat + sigma_hat * model_output
        derivative = (sample_hat - pred_original_sample) / sigma_hat
        # first-order step from sigma_hat to sigma_prev along the derivative
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative

        if not return_dict:
            return (sample_prev, derivative)

        return KarrasVeOutput(
            prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
        )

    def step_correct(
        self,
        model_output: torch.Tensor,
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: torch.Tensor,
        sample_prev: torch.Tensor,
        derivative: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[KarrasVeOutput, Tuple]:
        """
        Corrects the predicted sample based on the `model_output` of the network.

        A second derivative is computed at (`sample_prev`, `sigma_prev`) and the step from `sample_hat` is redone
        with the average of that derivative and the one passed in.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model, evaluated on `sample_prev`.
            sigma_hat (`float`):
                The noise level of `sample_hat`.
            sigma_prev (`float`):
                The noise level of `sample_prev`.
            sample_hat (`torch.Tensor`):
                The sample the step started from.
            sample_prev (`torch.Tensor`):
                The uncorrected sample at `sigma_prev` (the `prev_sample` returned by `step`).
            derivative (`torch.Tensor`):
                The derivative returned by `step`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
                The corrected sample together with the `derivative` argument, which is passed through unchanged (the
                corrected derivative is not returned). If return_dict is `False`, a tuple `(sample_prev, derivative)`
                is returned.

        """
        pred_original_sample = sample_prev + sigma_prev * model_output
        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)

        if not return_dict:
            return (sample_prev, derivative)

        return KarrasVeOutput(
            prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
        )

    def add_noise(self, original_samples, noise, timesteps):
        """Not supported by this scheduler; always raises `NotImplementedError`."""
        raise NotImplementedError()
def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
    """
    Build the continuous timestep grid for sampling (call this before inference).

    The grid is evenly spaced, starting at 1 and ending at `self.config.sampling_eps`.

    Args:
        num_inference_steps (`int`):
            How many timesteps the grid contains.
        device (`str` or `torch.device`, *optional*):
            Device on which the timesteps tensor is created. `None` uses the default device.
    """
    final_time = self.config.sampling_eps
    self.timesteps = torch.linspace(1, final_time, steps=num_inference_steps, device=device)
def gumbel_noise(t, generator=None):
    """
    Sample Gumbel noise `-log(-log(u))`, `u ~ U(0, 1)`, with the shape and dtype of `t`.

    The uniform draw happens on the generator's device when a generator is given (otherwise on
    `t.device`) and is moved to `t.device` afterwards. Both logarithm arguments are clamped to at
    least `1e-20`.
    """
    sampling_device = t.device if generator is None else generator.device
    uniform = torch.zeros_like(t, device=sampling_device).uniform_(0, 1, generator=generator)
    uniform = uniform.to(t.device)
    neg_log_uniform = -torch.log(uniform.clamp(1e-20))
    return -torch.log(neg_log_uniform.clamp(1e-20))


def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
    """
    Pick which positions to mask based on noisy confidence scores.

    The confidence of a position is `log(probs)` plus `temperature`-scaled Gumbel noise. Within each
    row, the confidences are sorted in ascending order and the value at index `mask_len` is used as
    a cut-off; every position whose confidence is strictly below the cut-off is marked.

    Returns:
        A boolean tensor shaped like `probs`, `True` where the position is selected for masking.
    """
    log_probs = torch.log(probs.clamp(1e-20))
    perturbation = temperature * gumbel_noise(probs, generator=generator)
    confidence = log_probs + perturbation

    ascending = torch.sort(confidence, dim=-1).values
    threshold = torch.gather(ascending, 1, mask_len.long())
    return confidence < threshold
self.config.mask_token_id probs = model_output.softmax(dim=-1) device = probs.device probs_ = probs.to(generator.device) if generator is not None else probs # handles when generator is on CPU if probs_.device.type == "cpu" and probs_.dtype != torch.float32: probs_ = probs_.float() # multinomial is not implemented for cpu half precision probs_ = probs_.reshape(-1, probs.size(-1)) pred_original_sample = torch.multinomial(probs_, 1, generator=generator).to(device=device) pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1]) pred_original_sample = torch.where(unknown_map, pred_original_sample, sample) if timestep == 0: prev_sample = pred_original_sample else: seq_len = sample.shape[1] step_idx = (self.timesteps == timestep).nonzero() ratio = (step_idx + 1) / len(self.timesteps) if self.config.masking_schedule == "cosine": mask_ratio = torch.cos(ratio * math.pi / 2) elif self.config.masking_schedule == "linear": mask_ratio = 1 - ratio else: raise ValueError(f"unknown masking schedule {self.config.masking_schedule}") mask_ratio = starting_mask_ratio * mask_ratio mask_len = (seq_len * mask_ratio).floor() # do not mask more than amount previously masked mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) # mask at least one mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len) selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0] # Ignores the tokens given in the input by overwriting their confidence. selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx], generator) # Masks tokens with lower confidence. 
# Adapted from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar (restyled, same behavior)
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous `alpha_bar(t)` curve, `t` in `[0, 1]`, into a beta schedule.

    `alpha_bar(t)` is the cumulative product of `(1 - beta)` up to time `t`. For step `i` the beta is
    `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): a 1-D `float32` tensor with `num_diffusion_timesteps` entries.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    # Guard clause first: reject unknown schedule names before doing any work.
    if alpha_transform_type not in ("cosine", "exp"):
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    else:

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    total = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta)
        for step in range(total)
    ]
    return torch.tensor(betas, dtype=torch.float32)
class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
    """
    Two-step scheduler for the consistency decoder.

    The constructor precomputes, per training timestep, the `c_skip`, `c_out` and `c_in` coefficients from a cosine
    beta schedule (`betas_for_alpha_bar`) and `sigma_data`. `set_timesteps` only accepts exactly two inference steps
    and always uses the fixed timesteps `[1008, 512]`.

    Args:
        num_train_timesteps (`int`, defaults to 1024):
            The number of timesteps the coefficient tables are computed for.
        sigma_data (`float`, defaults to 0.5):
            The `sigma_data` constant entering the `c_skip`, `c_out` and `c_in` coefficients.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1024,
        sigma_data: float = 0.5,
    ):
        betas = betas_for_alpha_bar(num_train_timesteps)

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

        sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)

        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)

        # per-timestep coefficients: x_0 = c_out * model_output + c_skip * sample (see `step`),
        # and the model input is sample * c_in (see `scale_model_input`)
        self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
        self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
        self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
    ):
        """
        Sets the fixed two-step timestep schedule and moves the coefficient tables to `device`.

        Args:
            num_inference_steps (`int`):
                Must be exactly 2; any other value (including the default `None`) is rejected.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps and coefficient tables are moved.

        Raises:
            ValueError: If `num_inference_steps` is not 2.
        """
        if num_inference_steps != 2:
            # The check rejects every value other than 2 (fewer steps too), so say exactly that.
            raise ValueError(
                f"`num_inference_steps` must be 2 (got {num_inference_steps}); other step counts are not supported."
            )

        self.timesteps = torch.tensor([1008, 512], dtype=torch.long, device=device)
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
        self.c_skip = self.c_skip.to(device)
        self.c_out = self.c_out.to(device)
        self.c_in = self.c_in.to(device)

    @property
    def init_noise_sigma(self):
        """
        `sqrt(1 - alphas_cumprod)` at the first inference timestep. Reads `self.timesteps`, so `set_timesteps` must
        have been called first.
        """
        return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Here the sample is multiplied by `c_in[timestep]`.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain, used to index `c_in`.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        return sample * self.c_in[timestep]

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[ConsistencyDecoderSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current timestep in the diffusion chain. Must be one of `self.timesteps`.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a
                [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.
        """
        x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample

        timestep_idx = torch.where(self.timesteps == timestep)[0]

        if timestep_idx == len(self.timesteps) - 1:
            # last step: return the denoised prediction as is
            prev_sample = x_0
        else:
            # otherwise re-noise the prediction to the level of the next timestep
            noise = randn_tensor(x_0.shape, generator=generator, dtype=x_0.dtype, device=x_0.device)
            prev_sample = (
                self.sqrt_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * x_0
                + self.sqrt_one_minus_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * noise
            )

        if not return_dict:
            return (prev_sample,)

        return ConsistencyDecoderSchedulerOutput(prev_sample=prev_sample)
class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
    """
    Multistep and onestep sampling for consistency models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 40):
            The number of diffusion steps to train the model.
        sigma_min (`float`, defaults to 0.002):
            Minimum noise magnitude in the sigma schedule. Defaults to 0.002 from the original implementation.
        sigma_max (`float`, defaults to 80.0):
            Maximum noise magnitude in the sigma schedule. Defaults to 80.0 from the original implementation.
        sigma_data (`float`, defaults to 0.5):
            The standard deviation of the data distribution from the EDM
            [paper](https://huggingface.co/papers/2206.00364). Defaults to 0.5 from the original implementation.
        s_noise (`float`, defaults to 1.0):
            The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
            1.011]. Defaults to 1.0 from the original implementation.
        rho (`float`, defaults to 7.0):
            The parameter for calculating the Karras sigma schedule from the EDM
            [paper](https://huggingface.co/papers/2206.00364). Defaults to 7.0 from the original implementation.
        clip_denoised (`bool`, defaults to `True`):
            Whether to clip the denoised outputs to `(-1, 1)`.

    An explicit timestep schedule is not a constructor argument: pass custom `timesteps` (in strictly descending
    order) to `set_timesteps` instead.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 40,
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        sigma_data: float = 0.5,
        s_noise: float = 1.0,
        rho: float = 7.0,
        clip_denoised: bool = True,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        ramp = np.linspace(0, 1, num_train_timesteps)
        sigmas = self._convert_to_karras(ramp)
        timesteps = self.sigma_to_t(sigmas)

        # setable values
        self.num_inference_steps = None
        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps)
        self.custom_timesteps = False
        self.is_scale_input_called = False
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Scales the consistency model input by dividing it by `(sigma**2 + sigma_data**2) ** 0.5`.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`float` or `torch.Tensor`):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        # Get sigma corresponding to timestep
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5)

        # Remembered so that `step` can warn when the input was never scaled.
        self.is_scale_input_called = True
        return sample

    def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
        """
        Gets scaled timesteps from the Karras sigmas for input to the consistency model.

        Args:
            sigmas (`float` or `np.ndarray`):
                A single Karras sigma or an array of Karras sigmas.

        Returns:
            `float` or `np.ndarray`:
                A scaled input timestep or scaled input timestep array.
        """
        if not isinstance(sigmas, np.ndarray):
            sigmas = np.array(sigmas, dtype=np.float64)

        # The tiny offset keeps the logarithm finite for a sigma of exactly zero.
        timesteps = 1000 * 0.25 * np.log(sigmas + 1e-44)

        return timesteps

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,
                `num_inference_steps` must be `None`.

        Raises:
            `ValueError`:
                If neither or both of `num_inference_steps` and `timesteps` are given, if `timesteps` is empty, not
                strictly descending or starts at or beyond `num_train_timesteps`, or if `num_inference_steps` exceeds
                `num_train_timesteps`.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")

        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")

        # Follow DDPMScheduler custom timesteps logic
        if timesteps is not None:
            # An empty schedule would otherwise surface as a bare IndexError on `timesteps[0]`.
            if len(timesteps) == 0:
                raise ValueError("`timesteps` must contain at least one timestep.")

            for i in range(1, len(timesteps)):
                if timesteps[i] >= timesteps[i - 1]:
                    raise ValueError("`timesteps` must be in descending order.")

            if timesteps[0] >= self.config.num_train_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.config.num_train_timesteps`:"
                    f" {self.config.num_train_timesteps}."
                )

            timesteps = np.array(timesteps, dtype=np.int64)
            self.custom_timesteps = True
        else:
            if num_inference_steps > self.config.num_train_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than"
                    f" `self.config.num_train_timesteps`: {self.config.num_train_timesteps} as the unet model"
                    f" trained with this scheduler can only handle maximal {self.config.num_train_timesteps}"
                    f" timesteps."
                )

            self.num_inference_steps = num_inference_steps

            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            self.custom_timesteps = False

        # Map timesteps to Karras sigmas directly for multistep sampling
        # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675
        num_train_timesteps = self.config.num_train_timesteps
        ramp = timesteps[::-1].copy()
        ramp = ramp / (num_train_timesteps - 1)
        sigmas = self._convert_to_karras(ramp)
        timesteps = self.sigma_to_t(sigmas)

        # A trailing `sigma_min` entry is the target of the final step.
        sigmas = np.concatenate([sigmas, [self.config.sigma_min]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)

        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Modified _convert_to_karras implementation that takes in ramp as argument
    def _convert_to_karras(self, ramp):
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = self.config.sigma_min
        sigma_max: float = self.config.sigma_max

        rho = self.config.rho
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def get_scalings(self, sigma):
        """
        Gets the `c_skip` / `c_out` scalings for `sigma` without the boundary-condition shift by `sigma_min`.

        Args:
            sigma (`torch.Tensor`):
                The current sigma in the Karras sigma schedule.

        Returns:
            `tuple`:
                A two-element tuple `(c_skip, c_out)`.
        """
        sigma_data = self.config.sigma_data

        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
        c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        return c_skip, c_out

    def get_scalings_for_boundary_condition(self, sigma):
        """
        Gets the scalings used in the consistency model parameterization (from Appendix C of the
        [paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition.

        <Tip>

        `epsilon` in the equations for `c_skip` and `c_out` is set to `sigma_min`.

        </Tip>

        Args:
            sigma (`torch.Tensor`):
                The current sigma in the Karras sigma schedule.

        Returns:
            `tuple`:
                A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out`
                (which weights the consistency model output) is the second element.
        """
        sigma_min = self.config.sigma_min
        sigma_data = self.config.sigma_data

        c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
        c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        return c_skip, c_out

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a
                [`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            `ValueError`:
                If `timestep` is an integer index instead of one of `scheduler.timesteps`.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    f" `{self.__class__}.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        sigma_min = self.config.sigma_min
        sigma_max = self.config.sigma_max

        if self.step_index is None:
            self._init_step_index(timestep)

        # sigma_next corresponds to next_t in original implementation
        sigma = self.sigmas[self.step_index]
        # Bound the lookup by the schedule that is actually active: after `set_timesteps` the sigma tensor is
        # usually much shorter than `num_train_timesteps`, so the training length is not a valid index bound.
        next_index = self.step_index + 1
        if next_index < len(self.sigmas):
            sigma_next = self.sigmas[next_index]
        else:
            # Set sigma_next to sigma_min
            sigma_next = self.sigmas[-1]

        # Get scalings for boundary conditions
        c_skip, c_out = self.get_scalings_for_boundary_condition(sigma)

        # 1. Denoise model output using boundary conditions
        denoised = c_out * model_output + c_skip * sample
        if self.config.clip_denoised:
            denoised = denoised.clamp(-1, 1)

        # 2. Sample z ~ N(0, s_noise^2 * I)
        # Noise is not used for onestep sampling.
        if len(self.timesteps) > 1:
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
        else:
            noise = torch.zeros_like(model_output)
        z = noise * self.config.s_noise

        sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max)

        # 3. Return noisy sample
        # tau = sigma_hat, eps = sigma_min
        prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
Args: sigma_min (`float`, *optional*, defaults to 0.3): Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1]. sigma_max (`float`, *optional*, defaults to 500): Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1]. sigma_data (`float`, *optional*, defaults to 1.0): The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1]. sigma_schedule (`str`, *optional*, defaults to `exponential`): Sigma schedule to compute the `sigmas`. By default, we use the exponential schedule, which was incorporated in this model: https://huggingface.co/stabilityai/cosxl. The other acceptable value is "karras", the schedule introduced in the EDM paper (https://arxiv.org/abs/2206.00364). num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. solver_order (`int`, defaults to 2): The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`. prediction_type (`str`, defaults to `v_prediction`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). Other values, including `sample` (directly predicts the noisy sample), are rejected with a `ValueError` when the model output is preconditioned. solver_type (`str`, defaults to `midpoint`): Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. lower_order_final (`bool`, defaults to `True`): Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. euler_at_final (`bool`, defaults to `False`): Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail richness.
@register_to_config
def __init__(
    self,
    sigma_min: float = 0.3,
    sigma_max: float = 500,
    sigma_data: float = 1.0,
    sigma_schedule: str = "exponential",
    num_train_timesteps: int = 1000,
    solver_order: int = 2,
    prediction_type: str = "v_prediction",
    rho: float = 7.0,
    solver_type: str = "midpoint",
    lower_order_final: bool = True,
    euler_at_final: bool = False,
    final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
):
    """
    Build the training-length sigma/timestep tables and reset the multistep solver state.

    Raises:
        `NotImplementedError`:
            If `solver_type` is neither `"midpoint"` nor `"heun"` nor one of the aliases mapped to `"midpoint"`.
        `ValueError`:
            If `sigma_schedule` is neither `"karras"` nor `"exponential"`.
    """
    if solver_type not in ["midpoint", "heun"]:
        if solver_type in ["logrho", "bh1", "bh2"]:
            # These solver names are accepted but stored in the config as `midpoint`.
            self.register_to_config(solver_type="midpoint")
        else:
            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

    ramp = torch.linspace(0, 1, num_train_timesteps)
    if sigma_schedule == "karras":
        sigmas = self._compute_karras_sigmas(ramp)
    elif sigma_schedule == "exponential":
        sigmas = self._compute_exponential_sigmas(ramp)
    else:
        # Without this branch an unknown schedule left `sigmas` unbound and failed with an
        # UnboundLocalError on the next line instead of a readable message.
        raise ValueError(
            f"`sigma_schedule` must be one of 'karras' or 'exponential', but got {sigma_schedule}"
        )

    self.timesteps = self.precondition_noise(sigmas)

    # A trailing zero sigma is appended after the schedule.
    self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

    # setable values
    self.num_inference_steps = None
    self.model_outputs = [None] * solver_order
    self.lower_order_nums = 0
    self._step_index = None
    self._begin_index = None
    # Created lazily by `step`; defined here so the attribute always exists.
    self.noise_sampler = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
def precondition_outputs(self, sample, model_output, sigma):
    """
    Combine the noisy `sample` and the raw `model_output` into a denoised estimate.

    The result is `c_skip * sample + c_out * model_output`, where both weights are derived from `sigma` and
    `config.sigma_data`; `c_out` is positive for `epsilon` and negative for `v_prediction`.

    Raises:
        `ValueError`: If `config.prediction_type` is neither `epsilon` nor `v_prediction`.
    """
    data_std = self.config.sigma_data
    total_variance = sigma**2 + data_std**2
    skip_weight = data_std**2 / total_variance

    prediction_type = self.config.prediction_type
    if prediction_type not in ("epsilon", "v_prediction"):
        raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")

    out_weight = sigma * data_std / total_variance**0.5
    if prediction_type == "v_prediction":
        # v-prediction uses the same magnitude with the opposite sign.
        out_weight = -out_weight

    return skip_weight * sample + out_weight * model_output
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    All configuration checks run before any scheduler attribute is modified, so a rejected call leaves the
    previous schedule intact.

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

    Raises:
        `ValueError`:
            If `num_inference_steps` is `None`, or if `config.sigma_schedule` / `config.final_sigmas_type` hold
            an unsupported value.
    """
    if num_inference_steps is None:
        raise ValueError("`num_inference_steps` must be supplied.")

    # Resolve both config switches up front. Previously an unknown `sigma_schedule` left `sigmas` unbound,
    # and a bad `final_sigmas_type` was only detected after `timesteps` had already been overwritten.
    if self.config.sigma_schedule == "karras":
        compute_sigmas = self._compute_karras_sigmas
    elif self.config.sigma_schedule == "exponential":
        compute_sigmas = self._compute_exponential_sigmas
    else:
        raise ValueError(
            f"`sigma_schedule` must be one of 'karras' or 'exponential', but got {self.config.sigma_schedule}"
        )

    if self.config.final_sigmas_type == "sigma_min":
        sigma_last = self.config.sigma_min
    elif self.config.final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
        )

    self.num_inference_steps = num_inference_steps

    ramp = torch.linspace(0, 1, self.num_inference_steps)
    sigmas = compute_sigmas(ramp).to(dtype=torch.float32, device=device)
    self.timesteps = self.precondition_noise(sigmas)

    # The extra trailing entry is the sigma reached by the final step.
    self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)])

    self.model_outputs = [
        None,
    ] * self.config.solver_order
    self.lower_order_nums = 0

    # add an index counter for schedulers that allow duplicated timesteps
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # if a noise sampler is used, reinitialise it
    self.noise_sampler = None
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
    """
    Constructs the noise schedule of Karras et al. (2022).

    `ramp` is interpolated between `sigma_max` (at `ramp == 0`) and `sigma_min` (at `ramp == 1`) in
    `sigma ** (1 / rho)` space. A falsy `sigma_min` / `sigma_max` (including `None`) falls back to the value
    stored in the config.
    """
    lower = sigma_min if sigma_min else self.config.sigma_min
    upper = sigma_max if sigma_max else self.config.sigma_max

    exponent = self.config.rho
    lower_root = lower ** (1 / exponent)
    upper_root = upper ** (1 / exponent)

    return (upper_root + ramp * (lower_root - upper_root)) ** exponent
def multistep_dpm_solver_second_order_update(
    self,
    model_output_list: List[torch.Tensor],
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    One step for the second-order multistep DPMSolver.

    Args:
        model_output_list (`List[torch.Tensor]`):
            The direct outputs from learned diffusion model at current and latter timesteps. Only the last two
            entries are used.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        noise (`torch.Tensor`):
            The noise injected by this stochastic update. Although the parameter defaults to `None`, a tensor
            is required.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.

    Raises:
        `ValueError`: If `noise` is `None`.
        `NotImplementedError`: If `config.solver_type` is neither `midpoint` nor `heun`.
    """
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if noise is None:
        raise ValueError("`noise` must be provided for the second-order SDE update.")

    sigma_t, sigma_s0, sigma_s1 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
    )

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

    m0, m1 = model_output_list[-1], model_output_list[-2]

    # h is the step being taken now, h_0 the previous one; r0 is their ratio.
    h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
    r0 = h_0 / h
    D0, D1 = m0, (1.0 / r0) * (m0 - m1)

    # sde-dpmsolver++
    if self.config.solver_type == "midpoint":
        x_t = (
            (sigma_t / sigma_s0 * torch.exp(-h)) * sample
            + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
            + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )
    elif self.config.solver_type == "heun":
        x_t = (
            (sigma_t / sigma_s0 * torch.exp(-h)) * sample
            + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
            + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )
    else:
        # Previously this fell through and failed with an UnboundLocalError on `x_t`.
        raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}")

    return x_t
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    generator=None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
    the multistep DPMSolver.

    The first-order update is used when `solver_order` is 1, when no earlier model output is stored yet, or on
    the final step when a lower-order finish applies. Every other step uses the second-order update, which is
    the highest order implemented here; larger `solver_order` values fall back to it.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator (or a list of them). Only its initial seed is read, once, when the noise
            sampler is created.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        `ValueError`: If `set_timesteps` has not been called yet.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    # Improve numerical stability for small number of steps
    lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
        self.config.euler_at_final
        or (self.config.lower_order_final and len(self.timesteps) < 15)
        or self.config.final_sigmas_type == "zero"
    )

    model_output = self.convert_model_output(model_output, sample=sample)

    # Shift the history window and store the newest converted output last.
    for i in range(self.config.solver_order - 1):
        self.model_outputs[i] = self.model_outputs[i + 1]
    self.model_outputs[-1] = model_output

    if self.noise_sampler is None:
        # Build the sampler lazily so that it can be seeded from the caller's generator(s).
        seed = None
        if generator is not None:
            seed = (
                [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
            )
        self.noise_sampler = BrownianTreeNoiseSampler(
            model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed
        )
    noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to(
        model_output.device
    )

    if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
        prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
    else:
        # The former `elif` had no fallback, so `solver_order > 2` eventually reached the return with
        # `prev_sample` unbound. For `solver_order == 2` that `elif` was always true, so this is equivalent.
        prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] elif self.step_index is not None: # add_noise is called after first denoising step (for inpainting) step_indices = [self.step_index] * timesteps.shape[0] else: # add noise is called before first denoising step to create initial latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: src/diffusers/schedulers/scheduling_ddim.py ================================================ # Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Build a beta schedule by discretizing an alpha_bar function, i.e. the cumulative product of (1 - beta) over
    time t in [0, 1].

    Each beta is `1 - alpha_bar(t2) / alpha_bar(t1)` for consecutive grid points `t1 = i / N` and
    `t2 = (i + 1) / N`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): a 1-D float32 tensor of `num_diffusion_timesteps` betas.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar_fn = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar_fn = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta) for step in range(total)
    ]
    return torch.tensor(betas, dtype=torch.float32)
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) # Convert alphas_bar_sqrt to betas alphas_bar = alphas_bar_sqrt**2 # Revert sqrt alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod alphas = torch.cat([alphas_bar[0:1], alphas]) betas = 1 - alphas return betas class DDIMScheduler(SchedulerMixin, ConfigMixin): """ `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. beta_start (`float`, defaults to 0.0001): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. beta_schedule (`str`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. clip_sample (`bool`, defaults to `True`): Clip the predicted sample for numerical stability. clip_sample_range (`float`, defaults to 1.0): The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. set_alpha_to_one (`bool`, defaults to `True`): Each diffusion step uses the alphas product value at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the alpha value at step 0. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. 
prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. dynamic_thresholding_ratio (`float`, defaults to 0.995): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, timestep_spacing: str = "leading", rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") # Rescale for zero SNR if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. 
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    """
    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
    prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
    s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
    pixels from saturation at each step. We find that dynamic thresholding results in significantly better
    photorealism as well as better image-text alignment, especially when using very large guidance weights."

    https://arxiv.org/abs/2205.11487

    Args:
        sample (`torch.Tensor`):
            The predicted original sample. Its first two dimensions are read as batch and channels; any further
            dimensions are flattened together with the channels for the per-sample quantile. A plain
            `(batch, channels)` tensor is accepted as well.

    Returns:
        `torch.Tensor`: the thresholded sample, with the shape and dtype of the input.
    """
    dtype = sample.dtype
    batch_size, channels, *remaining_dims = sample.shape

    if dtype not in (torch.float32, torch.float64):
        sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

    # Flatten sample for doing quantile calculation along each image.
    # np.prod of an empty list is the float 1.0, which reshape rejects as a dimension, so cast to int
    # to keep samples without trailing dimensions working.
    flat_features = channels * int(np.prod(remaining_dims))
    sample = sample.reshape(batch_size, flat_features)

    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

    s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
    s = torch.clamp(
        s, min=1, max=self.config.sample_max_value
    )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
    s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
    sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

    sample = sample.reshape(batch_size, channels, *remaining_dims)
    sample = sample.to(dtype)

    return sample
def step(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
    """
    Compute the sample at the previous timestep from the model output, following formulas (12) and (16) of the
    DDIM paper (https://arxiv.org/pdf/2010.02502.pdf).

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        eta (`float`):
            The weight of noise for added noise in diffusion step. No noise is added unless `eta > 0`.
        use_clipped_model_output (`bool`, defaults to `False`):
            If `True`, the predicted noise is re-derived from the clipped/thresholded predicted original sample
            before it is used for the update. If no clipping has happened, the re-derived value coincides with
            the one obtained from `model_output`.
        generator (`torch.Generator`, *optional*):
            A random number generator, used only when `eta > 0` and `variance_noise` is not given.
        variance_noise (`torch.Tensor`):
            Alternative to generating noise with `generator` by directly providing the noise for the variance
            itself.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        ValueError: if `set_timesteps` has not been called, if `prediction_type` is not supported, or if both
            `generator` and `variance_noise` are passed while `eta > 0`.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # Notation used below (DDIM paper):
    # - pred_epsilon          -> e_theta(x_t, t)
    # - pred_original_sample  -> f_theta(x_t, t) or x_0
    # - std_dev_t             -> sigma_t
    # - eta                   -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - prev_sample           -> "x_t-1"

    # 1. the timestep this update lands on
    stride = self.config.num_train_timesteps // self.num_inference_steps
    prev_timestep = timestep - stride

    # 2. cumulative alphas at both ends of the jump; past the start of the chain the configured final value is used
    alpha_prod_t = self.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
    else:
        alpha_prod_t_prev = self.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t
    sqrt_alpha_prod_t = alpha_prod_t**0.5
    sqrt_beta_prod_t = beta_prod_t**0.5

    # 3. "predicted x_0" of formula (12) and the matching noise estimate
    prediction_type = self.config.prediction_type
    if prediction_type == "epsilon":
        pred_epsilon = model_output
        pred_original_sample = (sample - sqrt_beta_prod_t * model_output) / sqrt_alpha_prod_t
    elif prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - sqrt_alpha_prod_t * pred_original_sample) / sqrt_beta_prod_t
    elif prediction_type == "v_prediction":
        pred_original_sample = sqrt_alpha_prod_t * sample - sqrt_beta_prod_t * model_output
        pred_epsilon = sqrt_alpha_prod_t * model_output + sqrt_beta_prod_t * sample
    else:
        raise ValueError(
            f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. threshold or clip "predicted x_0" (thresholding takes precedence over clipping)
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        clip_range = self.config.clip_sample_range
        pred_original_sample = pred_original_sample.clamp(-clip_range, clip_range)

    # 5. "sigma_t(η)" of formula (16):
    #    σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(timestep, prev_timestep)
    std_dev_t = eta * variance**0.5

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (sample - sqrt_alpha_prod_t * pred_original_sample) / sqrt_beta_prod_t

    # 6. "direction pointing to x_t" of formula (12)
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon

    # 7. x_t-1 without the "random noise" term of formula (12)
    prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction

    # 8. stochastic term, only for eta > 0
    if eta > 0:
        if variance_noise is not None and generator is not None:
            raise ValueError(
                "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                " `variance_noise` stays `None`."
            )

        step_noise = variance_noise
        if step_noise is None:
            step_noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
            )
        prev_sample = prev_sample + std_dev_t * step_noise

    if return_dict:
        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
    return (prev_sample,)
def __len__(self):
    """Return `num_train_timesteps` from the scheduler config."""
    num_train_timesteps = self.config.num_train_timesteps
    return num_train_timesteps
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion # and https://github.com/hojonathanho/diffusion import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM class DDIMSchedulerOutput(BaseOutput): """ Output class for the scheduler's `step` function output. Args: prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample `(x_{0})` based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.Tensor pred_original_sample: Optional[torch.Tensor] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
def rescale_zero_terminal_snr(alphas_cumprod):
    """
    Rescales a cumulative alpha schedule to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf
    (Algorithm 1)

    The square root of `alphas_cumprod` is shifted so that its last entry becomes zero and then scaled so that
    its first entry keeps its original value; the result is squared again before being returned. The input
    tensor is not modified.

    Args:
        alphas_cumprod (`torch.Tensor`):
            the cumulative product of alphas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled `alphas_cumprod` whose last entry is zero
    """
    sqrt_alphas_bar = alphas_cumprod.sqrt()

    # Reference values taken before any shifting.
    first_sqrt = sqrt_alphas_bar[0]
    last_sqrt = sqrt_alphas_bar[-1]

    # Shift so the last timestep is zero, then scale so the first timestep is back to the old value.
    shifted = sqrt_alphas_bar - last_sqrt
    rescaled = shifted * (first_sqrt / (first_sqrt - last_sqrt))

    # Revert the square root.
    return rescaled**2
beta_end (`float`, defaults to 0.02): The final `beta` value. beta_schedule (`str`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. clip_sample (`bool`, defaults to `True`): Clip the predicted sample for numerical stability. clip_sample_range (`float`, defaults to 1.0): The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. set_alpha_to_one (`bool`, defaults to `True`): Each diffusion step uses the alphas product value at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the alpha value at step 0. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. dynamic_thresholding_ratio (`float`, defaults to 0.995): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, defaults to `"leading"`): The way the timesteps should be scaled. 
Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.00085, beta_end: float = 0.0120, beta_schedule: str = "scaled_linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", clip_sample_range: float = 1.0, sample_max_value: float = 1.0, timestep_spacing: str = "leading", rescale_betas_zero_snr: bool = False, snr_shift_scale: float = 3.0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
def _get_variance(self, timestep, prev_timestep):
    """
    Compute `(1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev)` for one update.

    `alpha_t` is `alphas_cumprod[timestep]`. `alpha_prev` is `alphas_cumprod[prev_timestep]`, or
    `final_alpha_cumprod` when `prev_timestep` is negative.
    """
    alpha_bar_t = self.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_bar_prev = self.alphas_cumprod[prev_timestep]
    else:
        alpha_bar_prev = self.final_alpha_cumprod

    beta_ratio = (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    return beta_ratio * (1 - alpha_bar_t / alpha_bar_prev)
timestep (`int`, *optional*): The current timestep in the diffusion chain. Returns: `torch.Tensor`: A scaled input sample. """ return sample def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. """ if num_inference_steps > self.config.num_train_timesteps: raise ValueError( f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" f" maximal {self.config.num_train_timesteps} timesteps." ) self.num_inference_steps = num_inference_steps # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) .round()[::-1] .copy() .astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." 
) self.timesteps = torch.from_numpy(timesteps).to(device) def step( self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, generator=None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. timestep (`float`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. eta (`float`): The weight of noise for added noise in diffusion step. use_clipped_model_output (`bool`, defaults to `False`): If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would coincide with the one provided as input and `use_clipped_model_output` has no effect. generator (`torch.Generator`, *optional*): A random number generator. variance_noise (`torch.Tensor`): Alternative to generating noise with `generator` by directly providing the noise for the variance itself. Useful for methods such as [`CycleDiffusion`]. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. Returns: [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. 
""" if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding # Notation ( -> # - pred_noise_t -> e_theta(x_t, t) # - pred_original_sample -> f_theta(x_t, t) or x_0 # - std_dev_t -> sigma_t # - eta -> η # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # To make style tests pass, commented out `pred_epsilon` as it is an unused variable if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) # pred_epsilon = model_output elif self.config.prediction_type == "sample": pred_original_sample = model_output # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" ) a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t prev_sample = a_t * sample + b_t * pred_original_sample if not return_dict: return 
(prev_sample,) return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) timesteps = timesteps.to(sample.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(sample.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity def __len__(self): return self.config.num_train_timesteps ================================================ FILE: src/diffusers/schedulers/scheduling_ddim_flax.py ================================================ # Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import flax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
    CommonSchedulerState,
    FlaxKarrasDiffusionSchedulers,
    FlaxSchedulerMixin,
    FlaxSchedulerOutput,
    add_noise_common,
    get_velocity_common,
)


@flax.struct.dataclass
class DDIMSchedulerState:
    """State container threaded through the `FlaxDDIMScheduler` methods (the scheduler keeps no mutable state)."""

    # shared schedule quantities; this file reads `common.alphas_cumprod` from it
    common: CommonSchedulerState
    # value used in place of `alphas_cumprod[prev_timestep]` when `prev_timestep` is negative
    final_alpha_cumprod: jnp.ndarray

    # setable values
    init_noise_sigma: jnp.ndarray
    timesteps: jnp.ndarray
    # stays `None` until `FlaxDDIMScheduler.set_timesteps` has been called
    num_inference_steps: Optional[int] = None

    @classmethod
    def create(
        cls,
        common: CommonSchedulerState,
        final_alpha_cumprod: jnp.ndarray,
        init_noise_sigma: jnp.ndarray,
        timesteps: jnp.ndarray,
    ):
        """Build a state with the given fields; `num_inference_steps` keeps its default of `None`."""
        return cls(
            common=common,
            final_alpha_cumprod=final_alpha_cumprod,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
        )


@dataclass
class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
    """Output of `FlaxDDIMScheduler.step`: the base scheduler output plus the scheduler state."""

    state: DDIMSchedulerState


class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`jnp.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        clip_sample (`bool`, default `True`):
            option to clip the predicted sample for numerical stability. The clip range is determined by
            `clip_sample_range`.
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            An offset added to the inference steps, as required by some model families.
        prediction_type (`str`, default `epsilon`):
            indicates what the model predicts. `step` accepts `epsilon` (the noise), `sample` (the denoised sample)
            and `v_prediction`; any other value raises a `ValueError`.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            the `dtype` used for params and computation.
    """

    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

    dtype: jnp.dtype

    @property
    def has_state(self):
        """Always `True`: this scheduler's methods take and return an explicit `DDIMSchedulerState`."""
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[jnp.ndarray] = None,
        clip_sample: bool = True,
        clip_sample_range: float = 1.0,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        dtype: jnp.dtype = jnp.float32,
    ):
        # Only `dtype` is stored on the instance here; the other arguments are read back through `self.config`.
        self.dtype = dtype

    def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
        """
        Create the initial scheduler state.

        Args:
            common (`CommonSchedulerState`, *optional*):
                shared schedule state; when `None` it is built with `CommonSchedulerState.create(self)`.

        Returns:
            `DDIMSchedulerState`: state whose `timesteps` run from `num_train_timesteps - 1` down to `0` and whose
            `num_inference_steps` is still `None`.
        """
        if common is None:
            common = CommonSchedulerState.create(self)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        final_alpha_cumprod = (
            jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
        )

        # standard deviation of the initial noise distribution
        init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

        return DDIMSchedulerState.create(
            common=common,
            final_alpha_cumprod=final_alpha_cumprod,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
        )

    def scale_model_input(
        self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
    ) -> jnp.ndarray:
        """
        Return `sample` unchanged; present so this scheduler has the same interface as schedulers that do scale the
        model input.

        Args:
            state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
            sample (`jnp.ndarray`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `jnp.ndarray`: scaled input sample
        """
        return sample

    def set_timesteps(
        self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> DDIMSchedulerState:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`DDIMSchedulerState`):
                the `FlaxDDIMScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`, *optional*):
                not read by this method.

        Returns:
            `DDIMSchedulerState`: a copy of `state` with `num_inference_steps` and `timesteps` replaced.
        """
        step_ratio = self.config.num_train_timesteps // num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # rounding to avoid issues when num_inference_step is power of 3
        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset

        return state.replace(
            num_inference_steps=num_inference_steps,
            timesteps=timesteps,
        )

    def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
        """
        Compute `(1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)`, where `a` denotes `alphas_cumprod`;
        `state.final_alpha_cumprod` is used for `a_prev` when `prev_timestep` is negative.
        """
        alpha_prod_t = state.common.alphas_cumprod[timestep]
        # `jnp.where` rather than a Python conditional selects between the two candidates
        alpha_prod_t_prev = jnp.where(
            prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
        )
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def step(
        self,
        state: DDIMSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        eta: float = 0.0,
        return_dict: bool = True,
    ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            eta (`float`):
                multiplier applied to the standard deviation `sigma_t`. It only shrinks the "direction pointing to
                x_t" term; this method does not add any random noise.
            return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

        Returns:
            [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
            `tuple` `(prev_sample, state)`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been run on `state`, or if `config.prediction_type` is not one of
                `epsilon`, `sample` or `v_prediction`.
        """
        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (variable name -> name in paper)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

        alphas_cumprod = state.common.alphas_cumprod
        final_alpha_cumprod = state.final_alpha_cumprod

        # 2. compute alphas, betas
        alpha_prod_t = alphas_cumprod[timestep]
        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod)

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        # (`pred_epsilon` is not recomputed after clipping)
        if self.config.clip_sample:
            pred_original_sample = pred_original_sample.clip(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(state, timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if not return_dict:
            return (prev_sample, state)

        return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)

    def add_noise(
        self,
        state: DDIMSchedulerState,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """Delegate to `add_noise_common` using `state.common`."""
        return add_noise_common(state.common, original_samples, noise, timesteps)

    def get_velocity(
        self,
        state: DDIMSchedulerState,
        sample: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """Delegate to `get_velocity_common` using `state.common`."""
        return get_velocity_common(state.common, sample, noise, timesteps)

    def __len__(self):
        return self.config.num_train_timesteps


================================================
FILE: src/diffusers/schedulers/scheduling_ddim_inverse.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion # and https://github.com/hojonathanho/diffusion import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import BaseOutput, deprecate @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM class DDIMSchedulerOutput(BaseOutput): """ Output class for the scheduler's `step` function output. Args: prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample `(x_{0})` based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.Tensor pred_original_sample: Optional[torch.Tensor] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. 
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. Returns: `torch.Tensor`: rescaled betas with zero terminal SNR """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_bar_sqrt = alphas_cumprod.sqrt() # Store old values. alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() # Shift so the last timestep is zero. alphas_bar_sqrt -= alphas_bar_sqrt_T # Scale so the first timestep is back to the old value. 
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) # Convert alphas_bar_sqrt to betas alphas_bar = alphas_bar_sqrt**2 # Revert sqrt alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod alphas = torch.cat([alphas_bar[0:1], alphas]) betas = 1 - alphas return betas class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): """ `DDIMInverseScheduler` is the reverse scheduler of [`DDIMScheduler`]. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. beta_start (`float`, defaults to 0.0001): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. beta_schedule (`str`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. clip_sample (`bool`, defaults to `True`): Clip the predicted sample for numerical stability. clip_sample_range (`float`, defaults to 1.0): The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. set_alpha_to_one (`bool`, defaults to `True`): Each diffusion step uses the alphas product value at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to 0, otherwise it uses the alpha value at step `num_train_timesteps - 1`. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. 
prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). timestep_spacing (`str`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ order = 1 ignore_for_config = ["kwargs"] _deprecated_kwargs = ["set_alpha_to_zero"] @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", clip_sample_range: float = 1.0, timestep_spacing: str = "leading", rescale_betas_zero_snr: bool = False, **kwargs, ): if kwargs.get("set_alpha_to_zero", None) is not None: deprecation_message = ( "The `set_alpha_to_zero` argument is deprecated. Please use `set_alpha_to_one` instead." 
) deprecate("set_alpha_to_zero", "1.0.0", deprecation_message, standard_warn=False) set_alpha_to_one = kwargs["set_alpha_to_zero"] if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") # Rescale for zero SNR if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # At every step in inverted ddim, we are looking into the next alphas_cumprod # For the initial step, there is no current alphas_cumprod, and the index is out of bounds # `set_alpha_to_one` decides whether we set this parameter simply to one # in this case, self.step() just output the predicted noise # or whether we use the initial alpha used in training the diffusion model. self.initial_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64)) # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. 
Args: sample (`torch.Tensor`): The input sample. timestep (`int`, *optional*): The current timestep in the diffusion chain. Returns: `torch.Tensor`: A scaled input sample. """ return sample def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. """ if num_inference_steps > self.config.num_train_timesteps: raise ValueError( f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" f" maximal {self.config.num_train_timesteps} timesteps." ) self.num_inference_steps = num_inference_steps # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)[::-1]).astype(np.int64) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." 
def step(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union["DDIMSchedulerOutput", Tuple]:
    """
    Run one deterministic DDIM inversion step (no noise is added).

    The "predicted x_0" is computed with the cumulative alpha taken one inference stride *before* `timestep`
    (falling back to `self.initial_alpha_cumprod` when that index is negative), and the returned sample is
    re-noised with the cumulative alpha taken at `timestep` itself.

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a `DDIMSchedulerOutput` instead of a plain `tuple`.

    Returns:
        `DDIMSchedulerOutput` or `tuple`:
            If `return_dict` is `True`, a `DDIMSchedulerOutput` holding `prev_sample` and
            `pred_original_sample`; otherwise the tuple `(prev_sample, pred_original_sample)`.

    Raises:
        ValueError: If `self.config.prediction_type` is not `epsilon`, `sample` or `v_prediction`.
    """
    # 1. Index bookkeeping: `target_timestep` selects the noise level of the returned sample, while
    #    `source_timestep` (one stride earlier, capped at the last training index) selects the noise
    #    level used to decode the incoming sample.
    stride = self.config.num_train_timesteps // self.num_inference_steps
    target_timestep = timestep
    source_timestep = min(timestep - stride, self.config.num_train_timesteps - 1)

    # 2. Cumulative alphas for both noise levels.
    if source_timestep >= 0:
        alpha_prod_t = self.alphas_cumprod[source_timestep]
    else:
        alpha_prod_t = self.initial_alpha_cumprod
    alpha_prod_t_prev = self.alphas_cumprod[target_timestep]
    beta_prod_t = 1 - alpha_prod_t

    # 3. "predicted x_0" and the matching noise estimate, formula (12) of https://arxiv.org/pdf/2010.02502.pdf
    kind = self.config.prediction_type
    if kind == "epsilon":
        pred_epsilon = model_output
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif kind == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif kind == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Optionally clip the "predicted x_0" to the configured symmetric range.
    if self.config.clip_sample:
        clip_range = self.config.clip_sample_range
        pred_original_sample = pred_original_sample.clamp(-clip_range, clip_range)

    # 5. "direction pointing to x_t" of formula (12).
    pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon

    # 6. Recombine into the output sample; formula (12) without the random-noise term.
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    if return_dict:
        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
    return (prev_sample, pred_original_sample)
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Build a beta schedule by discretizing a continuous `alpha_bar(t)` function, i.e. the cumulative product of
    `(1 - beta)` over normalized time `t` in `[0, 1]`.

    For step `i` the beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): a 1-D `float32` tensor with `num_diffusion_timesteps` entries.

    Raises:
        ValueError: If `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar_fn = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar_fn = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    # Ratio of consecutive alpha_bar values gives the per-step alpha; its complement is the beta.
    schedule = [
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta) for step in range(total)
    ]
    return torch.tensor(schedule, dtype=torch.float32)
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2010.02502 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. clip_sample (`bool`, default `True`): option to clip predicted sample for numerical stability. clip_sample_range (`float`, default `1.0`): the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. set_alpha_to_one (`bool`, default `True`): each diffusion step uses the value of alphas product at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the value of alpha at step 0. steps_offset (`int`, default `0`): An offset added to the inference steps, as required by some model families. 
prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the denoised sample `x_0`) or `v_prediction` (see section 2.4 of https://imagen.research.google/video/paper.pdf) thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, default `"leading"`): The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. rescale_betas_zero_snr (`bool`, default `False`): whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). This can enable the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 _is_ode_scheduler = True @register_to_config # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.__init__ def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, timestep_spacing: str = "leading", rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") # Rescale for zero SNR if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. 
def _batch_get_variance(self, t, prev_t):
    """
    Batched DDIM variance, formula (16) of https://arxiv.org/pdf/2010.02502.pdf, for tensors of timesteps.

    Args:
        t (`torch.Tensor`): integer tensor of current timesteps, used to index `self.alphas_cumprod`.
        prev_t (`torch.Tensor`): integer tensor of previous timesteps, same shape as `t`. Negative entries
            mark "no previous timestep".

    Returns:
        `torch.Tensor`: `(1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)` evaluated elementwise, with the same
        shape as `t`.
    """
    alpha_prod_t = self.alphas_cumprod[t]

    # Negative indices are clipped to 0 only so that the gather is valid; those entries are overwritten below.
    gathered_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]

    # Where there is no previous timestep, fall back to `final_alpha_cumprod` exactly like the scalar
    # `_get_variance` does, instead of a hard-coded 1.0 (which is only right when `set_alpha_to_one=True`).
    final_alpha = torch.as_tensor(
        self.final_alpha_cumprod, dtype=gathered_prev.dtype, device=gathered_prev.device
    )
    no_previous = (prev_t < 0).to(gathered_prev.device)
    alpha_prod_t_prev = torch.where(no_previous, final_alpha, gathered_prev)

    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

    return variance
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Stores `num_inference_steps` on `self.num_inference_steps` and a descending `int64` tensor of timesteps on
    `self.timesteps`, built according to `self.config.timestep_spacing`.

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device the timesteps tensor is moved to.

    Raises:
        ValueError: If `num_inference_steps` exceeds `self.config.num_train_timesteps`, or if
            `self.config.timestep_spacing` is not one of `linspace`, `leading` or `trailing`.
    """
    num_train_timesteps = self.config.num_train_timesteps

    if num_inference_steps > num_train_timesteps:
        raise ValueError(
            f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
            f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
            f" maximal {self.config.num_train_timesteps} timesteps."
        )

    self.num_inference_steps = num_inference_steps
    spacing = self.config.timestep_spacing

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    if spacing == "linspace":
        timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1].copy().astype(np.int64)
    elif spacing == "leading":
        step_ratio = num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        timesteps += self.config.steps_offset
    elif spacing == "trailing":
        step_ratio = num_train_timesteps / self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64)
        timesteps -= 1
    else:
        # The message lists every spacing handled above, including 'linspace'.
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace',"
            " 'leading' or 'trailing'."
        )

    self.timesteps = torch.from_numpy(timesteps).to(device)
Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.Tensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.Tensor`): current instance of sample being created by diffusion process. eta (`float`): weight of noise for added noise in diffusion step. use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would coincide with the one provided as input and `use_clipped_model_output` will have not effect. generator: random number generator. variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we can directly provide the noise for the variance itself. This is useful for methods such as CycleDiffusion. (https://arxiv.org/abs/2210.05559) return_dict (`bool`): option for returning tuple rather than DDIMParallelSchedulerOutput class Returns: [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding # Notation ( -> # - pred_noise_t -> e_theta(x_t, t) # - pred_original_sample -> f_theta(x_t, t) or x_0 # - std_dev_t -> sigma_t # - eta -> η # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" # 1. 
get previous step value (=t-1) prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_epsilon = model_output elif self.config.prediction_type == "sample": pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" ) # 4. Clip or threshold "predicted x_0" if self.config.thresholding: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range ) # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = self._get_variance(timestep, prev_timestep) std_dev_t = eta * variance ** (0.5) if use_clipped_model_output: # the pred_epsilon is always re-derived from the clipped x_0 in Glide pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # 6. 
def batch_step_no_noise(
    self,
    model_output: torch.Tensor,
    timesteps: List[int],
    sample: torch.Tensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
) -> torch.Tensor:
    """
    Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
    Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the
    noise is pre-sampled by the pipeline.

    Args:
        model_output (`torch.Tensor`): direct output from learned diffusion model; its leading dimension is
            paired elementwise with `timesteps`.
        timesteps (`List[int]` or `torch.Tensor`): current discrete timesteps in the diffusion chain, one per
            leading-dimension entry of `model_output`.
        sample (`torch.Tensor`): current instance of sample being created by diffusion process.
        eta (`float`): weight of noise for added noise in diffusion step. Must be `0.0`.
        use_clipped_model_output (`bool`): if `True`, re-derive the predicted noise from the (possibly
            clipped/thresholded) predicted original sample.

    Returns:
        `torch.Tensor`: sample tensor at previous timestep.

    Raises:
        ValueError: If `set_timesteps` has not been called, if `eta` is not `0.0`, or if
            `self.config.prediction_type` is not `epsilon`, `sample` or `v_prediction`.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # Explicit check instead of `assert`, which is stripped when Python runs with `-O`.
    if eta != 0.0:
        raise ValueError(f"`batch_step_no_noise` does not add noise and only supports `eta=0.0`, got {eta}.")

    # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Notation:
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - pred_sample_direction -> "direction pointing to x_t"
    # - prev_sample -> "x_t-1"

    device = model_output.device

    # 1. get previous step value (=t-1); accept a plain list as advertised by the signature
    t = torch.as_tensor(timesteps, device=device)
    prev_t = t - self.config.num_train_timesteps // self.num_inference_steps

    # Reshape to (batch, 1, 1, ...) so the per-timestep coefficients broadcast over `model_output`.
    broadcast_shape = (-1,) + (1,) * (model_output.ndim - 1)
    t = t.reshape(broadcast_shape)
    prev_t = prev_t.reshape(broadcast_shape)

    # 2. compute alphas, betas (the schedule tensors are moved once and kept on the model's device)
    self.alphas_cumprod = self.alphas_cumprod.to(device)
    self.final_alpha_cumprod = self.final_alpha_cumprod.to(device)
    alpha_prod_t = self.alphas_cumprod[t]
    # Entries without a previous timestep use `final_alpha_cumprod`, matching the non-batched `step`,
    # rather than a hard-coded 1.0 that ignores `set_alpha_to_one=False`.
    alpha_prod_t_prev = torch.where(
        prev_t < 0,
        self.final_alpha_cumprod.to(self.alphas_cumprod.dtype),
        self.alphas_cumprod[torch.clip(prev_t, min=0)],
    )
    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._batch_get_variance(t, prev_t).to(device).view(*alpha_prod_t_prev.shape)
    std_dev_t = eta * variance ** (0.5)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    return prev_sample
sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity def __len__(self): return self.config.num_train_timesteps ================================================ FILE: src/diffusers/schedulers/scheduling_ddpm.py ================================================ # Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @dataclass class DDPMSchedulerOutput(BaseOutput): """ Output class for the scheduler's `step` function output. Args: prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the denoising loop. 
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample `(x_{0})` based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.Tensor pred_original_sample: Optional[torch.Tensor] = None def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. 
class DDPMScheduler(SchedulerMixin, ConfigMixin):
    """
    `DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
        trained_betas (`np.ndarray`, *optional*):
            An array of betas to pass directly to the constructor without using `beta_start` and `beta_end`.
        variance_type (`str`, defaults to `"fixed_small"`):
            Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`,
            `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the denoised sample `x_0`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        variance_type: str = "fixed_small",
        clip_sample: bool = True,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        elif beta_schedule == "sigmoid":
            # GeoDiff sigmoid schedule
            betas = torch.linspace(-6, 6, num_train_timesteps)
            self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Stands in for alpha_bar of the (non-existent) step before t = 0.
        self.one = torch.tensor(1.0)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.custom_timesteps = False
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

        self.variance_type = variance_type

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. This scheduler applies no scaling.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain. Unused.

        Returns:
            `torch.Tensor`:
                The input sample, returned unchanged.
        """
        return sample

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model. If used,
                `timesteps` must be `None`.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,
                `num_inference_steps` must be `None`.

        Raises:
            ValueError: If both or neither of `num_inference_steps` and `timesteps` are given, if the custom timesteps
                are empty, not strictly descending or start at/after `num_train_timesteps`, if `num_inference_steps`
                exceeds `num_train_timesteps`, or if the configured `timestep_spacing` is unknown.
        """
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")

        # Without this guard the code below fails on `None > int` with an unhelpful TypeError.
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")

        if timesteps is not None:
            # An empty schedule would otherwise surface as an IndexError on `timesteps[0]`.
            if len(timesteps) == 0:
                raise ValueError("`custom_timesteps` must not be empty.")

            for i in range(1, len(timesteps)):
                if timesteps[i] >= timesteps[i - 1]:
                    raise ValueError("`custom_timesteps` must be in descending order.")

            if timesteps[0] >= self.config.num_train_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps}."
                )

            timesteps = np.array(timesteps, dtype=np.int64)
            self.custom_timesteps = True
        else:
            if num_inference_steps > self.config.num_train_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                    f" maximal {self.config.num_train_timesteps} timesteps."
                )

            self.num_inference_steps = num_inference_steps
            self.custom_timesteps = False

            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
            if self.config.timestep_spacing == "linspace":
                timesteps = (
                    np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                    .round()[::-1]
                    .copy()
                    .astype(np.int64)
                )
            elif self.config.timestep_spacing == "leading":
                step_ratio = self.config.num_train_timesteps // self.num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
                timesteps += self.config.steps_offset
            elif self.config.timestep_spacing == "trailing":
                step_ratio = self.config.num_train_timesteps / self.num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
                timesteps -= 1
            else:
                raise ValueError(
                    f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def _get_variance(self, t, predicted_variance=None, variance_type=None):
        """
        Compute the noise term for timestep `t` according to `variance_type`.

        Args:
            t: Current timestep, used to index `self.alphas_cumprod`.
            predicted_variance (*optional*): Model-predicted variance channels; read only by the `learned` and
                `learned_range` types.
            variance_type (`str`, *optional*): Overrides `self.config.variance_type` when given.

        Returns:
            The posterior variance for `fixed_small`/`fixed_large`, its standard deviation for `fixed_small_log`, a
            log-variance for `fixed_large_log`/`learned_range`, or `predicted_variance` itself for `learned`.
        """
        prev_t = self.previous_timestep(t)

        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t

        # we always take the log of variance, so clamp it to ensure it's not 0
        variance = torch.clamp(variance, min=1e-20)

        if variance_type is None:
            variance_type = self.config.variance_type

        # hacks - were probably added for training stability
        if variance_type == "fixed_small":
            # The clamped posterior variance is used as-is.
            pass
        # for rl-diffuser https://arxiv.org/abs/2205.09991
        elif variance_type == "fixed_small_log":
            variance = torch.log(variance)
            variance = torch.exp(0.5 * variance)
        elif variance_type == "fixed_large":
            variance = current_beta_t
        elif variance_type == "fixed_large_log":
            # Glide max_log
            variance = torch.log(current_beta_t)
        elif variance_type == "learned":
            return predicted_variance
        elif variance_type == "learned_range":
            # Interpolate in log space between the posterior variance and beta_t.
            min_log = torch.log(variance)
            max_log = torch.log(current_beta_t)
            frac = (predicted_variance + 1) / 2
            variance = frac * max_log + (1 - frac) * min_log

        return variance

    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image.
        # `np.prod([])` is the float 1.0, so cast to int to keep `(batch, channels)` inputs reshapeable.
        sample = sample.reshape(batch_size, channels * int(np.prod(remaining_dims)))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        generator=None,
        return_dict: bool = True,
    ) -> Union[DDPMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `self.config.prediction_type` is not `epsilon`, `sample` or `v_prediction`.
        """
        t = timestep

        prev_t = self.previous_timestep(t)

        # A model with learned variance emits twice the sample's channels: mean first, variance second.
        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
                " `v_prediction`  for the DDPMScheduler."
            )

        # 3. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise (skipped at t == 0)
        variance = 0
        if t > 0:
            device = model_output.device
            variance_noise = randn_tensor(
                model_output.shape, generator=generator, device=device, dtype=model_output.dtype
            )
            if self.variance_type == "fixed_small_log":
                # `_get_variance` already returns a standard deviation for this type.
                variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
            elif self.variance_type == "learned_range":
                # `_get_variance` returns a log-variance for this type.
                variance = self._get_variance(t, predicted_variance=predicted_variance)
                variance = torch.exp(0.5 * variance) * variance_noise
            else:
                variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise

        pred_prev_sample = pred_prev_sample + variance

        if not return_dict:
            return (pred_prev_sample,)

        return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Diffuse `original_samples` forward to `timesteps`:
        `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise`.

        Side effect: `self.alphas_cumprod` is moved to the device of `original_samples`.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
        # for the subsequent add_noise calls
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        # Append trailing singleton dims so the per-sample coefficients broadcast over the sample dims.
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
        """
        Compute the velocity target `sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample`.

        Side effect: `self.alphas_cumprod` is moved to the device of `sample`.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        # Append trailing singleton dims so the per-sample coefficients broadcast over the sample dims.
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        return self.config.num_train_timesteps

    def previous_timestep(self, timestep):
        """
        Return the timestep that follows `timestep` in the denoising order.

        With custom timesteps, this is the next entry of `self.timesteps`, or a `-1` tensor after the last entry.
        Otherwise it is `timestep` minus the integer stride `num_train_timesteps // num_inference_steps`, where
        `num_inference_steps` falls back to `num_train_timesteps` if `set_timesteps` has not set it.
        """
        if self.custom_timesteps:
            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
            if index == self.timesteps.shape[0] - 1:
                prev_t = torch.tensor(-1)
            else:
                prev_t = self.timesteps[index + 1]
        else:
            num_inference_steps = (
                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
            )
            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps

        return prev_t
class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
    Langevin dynamics sampling.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2006.11239

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        variance_type (`str`):
            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between -1 and 1 for numerical stability.
        prediction_type (`str`, default `epsilon`):
            indicates what the model predicts. One of `epsilon` (the noise), `sample` (the denoised sample) or
            `v_prediction`.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            the `dtype` used for params and computation.
    """

    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

    dtype: jnp.dtype

    @property
    def has_state(self):
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[jnp.ndarray] = None,
        variance_type: str = "fixed_small",
        clip_sample: bool = True,
        prediction_type: str = "epsilon",
        dtype: jnp.dtype = jnp.float32,
    ):
        self.dtype = dtype

    def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
        """
        Build the initial scheduler state, with timesteps covering every training step in descending order.

        Args:
            common (`CommonSchedulerState`, optional):
                shared schedule state; created from this scheduler when omitted.

        Returns:
            `DDPMSchedulerState`: the freshly created state.
        """
        if common is None:
            common = CommonSchedulerState.create(self)

        # standard deviation of the initial noise distribution
        init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

        return DDPMSchedulerState.create(
            common=common,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
        )

    def scale_model_input(
        self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
    ) -> jnp.ndarray:
        """
        Args:
            state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance.
            sample (`jnp.ndarray`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `jnp.ndarray`: the input sample, returned unchanged
        """
        return sample

    def set_timesteps(
        self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> DDPMSchedulerState:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`DDPMSchedulerState`):
                the `FlaxDDPMScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`):
                unused by this scheduler.

        Returns:
            `DDPMSchedulerState`: a copy of `state` with `num_inference_steps` and `timesteps` replaced.
        """
        step_ratio = self.config.num_train_timesteps // num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # rounding to avoid issues when num_inference_step is power of 3
        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]

        return state.replace(
            num_inference_steps=num_inference_steps,
            timesteps=timesteps,
        )

    def _get_variance(self, state: DDPMSchedulerState, t, predicted_variance=None, variance_type=None):
        """
        Compute the noise term for timestep `t` according to `variance_type`.

        Returns a variance for `fixed_small`/`fixed_large`, a log-variance for `fixed_small_log`, `fixed_large_log`
        and `learned_range`, and `predicted_variance` itself for `learned`. `variance_type` overrides
        `self.config.variance_type` when given.
        """
        alpha_prod_t = state.common.alphas_cumprod[t]
        alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype))

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * state.common.betas[t]

        if variance_type is None:
            variance_type = self.config.variance_type

        # hacks - were probably added for training stability
        if variance_type == "fixed_small":
            variance = jnp.maximum(variance, 1e-20)
        # for rl-diffuser https://arxiv.org/abs/2205.09991
        elif variance_type == "fixed_small_log":
            # Lower-bound first so the log stays finite.
            variance = jnp.log(jnp.maximum(variance, 1e-20))
        elif variance_type == "fixed_large":
            variance = state.common.betas[t]
        elif variance_type == "fixed_large_log":
            # Glide max_log
            variance = jnp.log(state.common.betas[t])
        elif variance_type == "learned":
            return predicted_variance
        elif variance_type == "learned_range":
            # Interpolate in log space between the posterior variance and beta_t.
            min_log = jnp.log(jnp.maximum(variance, 1e-20))
            max_log = jnp.log(state.common.betas[t])
            frac = (predicted_variance + 1) / 2
            variance = frac * max_log + (1 - frac) * min_log

        return variance

    def step(
        self,
        state: DDPMSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        key: Optional[jax.Array] = None,
        return_dict: bool = True,
    ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            key (`jax.Array`):
                a PRNG key. When omitted, `jax.random.key(0)` is used, so every call draws the same noise.
            return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class

        Returns:
            [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `self.config.prediction_type` is not `epsilon`, `sample` or `v_prediction`.
        """
        # NOTE(review): the previous step is always taken to be `t - 1` and `betas[t]` is used directly, regardless
        # of the stride chosen in `set_timesteps` — confirm this is intended for strided inference.
        t = timestep

        if key is None:
            key = jax.random.key(0)

        if (
            len(model_output.shape) > 1
            and model_output.shape[1] == sample.shape[1] * 2
            and self.config.variance_type in ["learned", "learned_range"]
        ):
            # An int passed to `jnp.split` is the number of equal sections, so split into the two halves
            # (mean, variance) rather than into `sample.shape[1]` pieces.
            model_output, predicted_variance = jnp.split(model_output, 2, axis=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        alpha_prod_t = state.common.alphas_cumprod[t]
        alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype))
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
                " `v_prediction` for the FlaxDDPMScheduler."
            )

        # 3. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = jnp.clip(pred_original_sample, -1, 1)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * state.common.betas[t]) / beta_prod_t
        current_sample_coeff = state.common.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise
        def random_variance():
            split_key = jax.random.split(key, num=1)[0]
            noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)
            variance = self._get_variance(state, t, predicted_variance=predicted_variance)
            if self.config.variance_type in ("fixed_small_log", "fixed_large_log", "learned_range"):
                # These types yield a log-variance; a plain square root of the (negative) log would be NaN.
                std = jnp.exp(0.5 * variance)
            else:
                std = variance**0.5
            return std * noise

        # No noise is added on the final step (t == 0).
        variance = jnp.where(t > 0, random_variance(), jnp.zeros(model_output.shape, dtype=self.dtype))

        pred_prev_sample = pred_prev_sample + variance

        if not return_dict:
            return (pred_prev_sample, state)

        return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)

    def add_noise(
        self,
        state: DDPMSchedulerState,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """Delegate to `add_noise_common` with the shared schedule state in `state.common`."""
        return add_noise_common(state.common, original_samples, noise, timesteps)

    def get_velocity(
        self,
        state: DDPMSchedulerState,
        sample: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """Delegate to `get_velocity_common` with the shared schedule state in `state.common`."""
        return get_velocity_common(state.common, sample, noise, timesteps)

    def __len__(self):
        return self.config.num_train_timesteps
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput class DDPMParallelSchedulerOutput(BaseOutput): """ Output class for the scheduler's `step` function output. Args: prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample `(x_{0})` based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. 
""" prev_sample: torch.Tensor pred_original_sample: Optional[torch.Tensor] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. 
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that the final timestep has zero signal-to-noise ratio, following Algorithm 1 of
    https://arxiv.org/pdf/2305.08891.pdf.

    Args:
        betas (`torch.Tensor`):
            The betas that the scheduler is being initialized with. The input tensor is not modified.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # sqrt of the cumulative alpha product is the signal scale at every timestep
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift so the last timestep becomes zero, then stretch so the first timestep keeps its old value.
    shifted = sqrt_alpha_bar - last
    rescaled = shifted * (first / (first - last))

    # Undo the sqrt, then undo the cumulative product to recover per-step alphas.
    alpha_bar = rescaled**2
    per_step_alpha = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])

    return 1 - per_step_alpha
variance_type (`str`): options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample for numerical stability. clip_sample_range (`float`, default `1.0`): the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, default `"leading"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): An offset added to the inference steps, as required by some model families. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. 
Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 _is_ode_scheduler = False @register_to_config # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.__init__ def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, variance_type: str = "fixed_small", clip_sample: bool = True, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, timestep_spacing: str = "leading", steps_offset: int = 0, rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
def set_timesteps(
    self,
    num_inference_steps: Optional[int] = None,
    device: Union[str, torch.device] = None,
    timesteps: Optional[List[int]] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Exactly one of `num_inference_steps` and `timesteps` must be given.

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. Must be non-empty, strictly
            descending and start below `self.config.num_train_timesteps`. If `timesteps` is passed,
            `num_inference_steps` must be `None`.

    Raises:
        ValueError: if both or neither of `num_inference_steps` / `timesteps` are given, if the custom timesteps
            are empty, not strictly descending or start too late, if `num_inference_steps` exceeds
            `self.config.num_train_timesteps`, or if `self.config.timestep_spacing` is not supported.
    """
    if num_inference_steps is not None and timesteps is not None:
        raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
    # Without this guard the `None > int` comparison below fails with an opaque TypeError.
    if num_inference_steps is None and timesteps is None:
        raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")

    if timesteps is not None:
        if len(timesteps) == 0:
            raise ValueError("`timesteps` must contain at least one timestep.")

        if any(later >= earlier for earlier, later in zip(timesteps, timesteps[1:])):
            raise ValueError("`custom_timesteps` must be in descending order.")

        if timesteps[0] >= self.config.num_train_timesteps:
            raise ValueError(
                f"`timesteps` must start before `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps}."
            )

        timesteps = np.array(timesteps, dtype=np.int64)
        self.custom_timesteps = True
    else:
        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps
        self.custom_timesteps = False

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # integer stride, so the schedule is 0, ratio, 2 * ratio, ... reversed, then shifted by `steps_offset`
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # fractional stride counted down from `num_train_timesteps`; rounded, then shifted to 0-based indices
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

    self.timesteps = torch.from_numpy(timesteps).to(device)
def _get_variance(self, t, predicted_variance=None, variance_type=None):
    """
    Return the variance term used when sampling `x_{t-1}` at timestep `t`.

    Args:
        t: current timestep; used to index `self.alphas_cumprod`.
        predicted_variance: model-predicted variance, only consumed by the `learned` and `learned_range` modes.
        variance_type (`str`, *optional*): overrides `self.config.variance_type` when given.

    Returns:
        Depending on the mode: the clamped posterior variance (`fixed_small` and any unrecognized mode), its
        square root (`fixed_small_log`), the current beta (`fixed_large`), the log of the current beta
        (`fixed_large_log`), `predicted_variance` unchanged (`learned`), or a log-variance interpolated between
        the posterior and beta bounds (`learned_range`).
    """
    previous_t = self.previous_timestep(t)

    cumprod_now = self.alphas_cumprod[t]
    cumprod_before = self.alphas_cumprod[previous_t] if previous_t >= 0 else self.one
    beta_now = 1 - cumprod_now / cumprod_before

    # Posterior variance, formulas (6) and (7) of https://arxiv.org/pdf/2006.11239.pdf.
    # Clamped away from zero because several modes below take its log.
    posterior = torch.clamp((1 - cumprod_before) / (1 - cumprod_now) * beta_now, min=1e-20)

    mode = self.config.variance_type if variance_type is None else variance_type

    if mode == "learned":
        return predicted_variance
    if mode == "fixed_small_log":
        # for rl-diffuser https://arxiv.org/abs/2205.09991
        return torch.exp(0.5 * torch.log(posterior))
    if mode == "fixed_large":
        return beta_now
    if mode == "fixed_large_log":
        # Glide max_log
        return torch.log(beta_now)
    if mode == "learned_range":
        lower_log = torch.log(posterior)
        upper_log = torch.log(beta_now)
        weight = (predicted_variance + 1) / 2
        return weight * upper_log + (1 - weight) * lower_log

    # "fixed_small" (and anything unrecognized) uses the posterior variance as is
    return posterior
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    """
    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
    prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
    s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
    pixels from saturation at each step. We find that dynamic thresholding results in significantly better
    photorealism as well as better image-text alignment, especially when using very large guidance weights."

    https://arxiv.org/abs/2205.11487

    Args:
        sample (`torch.Tensor`):
            Predicted `x_0` with at least two dimensions `(batch_size, channels, ...)`. Any number of trailing
            dimensions (including none) is supported.

    Returns:
        `torch.Tensor`: thresholded sample with the same shape and dtype as the input.
    """
    dtype = sample.dtype
    batch_size, channels, *remaining_dims = sample.shape

    if dtype not in (torch.float32, torch.float64):
        sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

    # Flatten sample for doing quantile calculation along each image.
    # `math.prod` returns the int 1 for no trailing dims; `np.prod([])` is the float 1.0, which `reshape` rejects.
    sample = sample.reshape(batch_size, channels * math.prod(remaining_dims))

    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

    s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
    s = torch.clamp(
        s, min=1, max=self.config.sample_max_value
    )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
    s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
    sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

    sample = sample.reshape(batch_size, channels, *remaining_dims)
    sample = sample.to(dtype)

    return sample
def step(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    generator=None,
    return_dict: bool = True,
) -> Union[DDPMParallelSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.Tensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            current instance of sample being created by diffusion process.
        generator: random number generator.
        return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class

    Returns:
        [`~schedulers.scheduling_ddpm_parallel.DDPMParallelSchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_ddpm_parallel.DDPMParallelSchedulerOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is the sample tensor.

    Raises:
        ValueError: if `self.config.prediction_type` is not one of `epsilon`, `sample` or `v_prediction`.
    """
    t = timestep

    prev_t = self.previous_timestep(t)

    # A model that also predicts the variance emits twice the channels; split them apart.
    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    # 1. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[t]
    alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMParallelScheduler."
        )

    # 3. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
    current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

    # 6. Add noise (skipped at t == 0, where the result is the posterior mean)
    variance = 0
    if t > 0:
        device = model_output.device
        variance_noise = randn_tensor(
            model_output.shape, generator=generator, device=device, dtype=model_output.dtype
        )
        if self.variance_type == "fixed_small_log":
            # `_get_variance` already returns a standard deviation in this mode
            variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
        elif self.variance_type == "learned_range":
            # `_get_variance` returns a log-variance in this mode
            variance = self._get_variance(t, predicted_variance=predicted_variance)
            variance = torch.exp(0.5 * variance) * variance_noise
        else:
            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise

    pred_prev_sample = pred_prev_sample + variance

    if not return_dict:
        return (pred_prev_sample,)

    return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
def batch_step_no_noise(
    self,
    model_output: torch.Tensor,
    timesteps: Union[List[int], torch.Tensor],
    sample: torch.Tensor,
) -> torch.Tensor:
    """
    Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
    Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise
    is pre-sampled by the pipeline.

    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    The previous timestep is always derived from a fixed stride
    (`num_train_timesteps // num_inference_steps`); `self.custom_timesteps` is not consulted here.

    Args:
        model_output (`torch.Tensor`): direct output from learned diffusion model.
        timesteps (`List[int]` or `torch.Tensor`):
            current discrete timesteps in the diffusion chain, one per batch entry of `model_output`.
        sample (`torch.Tensor`):
            current instance of sample being created by diffusion process.

    Returns:
        `torch.Tensor`: sample tensor at previous timestep.

    Raises:
        ValueError: if `self.config.prediction_type` is not one of `epsilon`, `sample` or `v_prediction`.
    """
    # Accept a plain list as well as a tensor; everything below needs tensor indexing.
    t = torch.as_tensor(timesteps, device=model_output.device)
    num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
    prev_t = t - self.config.num_train_timesteps // num_inference_steps

    # One scalar per batch entry, broadcastable against `model_output`.
    per_sample_shape = (-1, *([1] * (model_output.ndim - 1)))
    t = t.view(per_sample_shape)
    prev_t = prev_t.view(per_sample_shape)

    # Drop the predicted-variance half of the channels: no noise is added here, so it is not needed.
    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, _ = torch.split(model_output, sample.shape[1], dim=1)

    # 1. compute alphas, betas
    self.alphas_cumprod = self.alphas_cumprod.to(model_output.device)
    alpha_prod_t = self.alphas_cumprod[t]
    alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
    # Before the first timestep the cumulative alpha product is 1 by definition.
    alpha_prod_t_prev = torch.where(prev_t < 0, torch.ones_like(alpha_prod_t_prev), alpha_prod_t_prev)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMParallelScheduler."
        )

    # 3. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
    current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

    return pred_prev_sample
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
    """
    Compute the velocity target `sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample` for each batch entry.

    Args:
        sample (`torch.Tensor`): clean samples.
        noise (`torch.Tensor`): noise tensors.
        timesteps (`torch.IntTensor`): indices into `self.alphas_cumprod`.

    Returns:
        `torch.Tensor`: the velocity.
    """
    # Keep the cached schedule on the sample's device; cast a working copy to the sample's dtype.
    self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
    cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    # One scale per batch entry, padded with trailing singleton dims so it broadcasts against `sample`.
    trailing_ones = [1] * (sample.ndim - 1)
    noise_weight = (cumprod[timesteps] ** 0.5).flatten().reshape(-1, *trailing_ones)
    sample_weight = ((1 - cumprod[timesteps]) ** 0.5).flatten().reshape(-1, *trailing_ones)

    return noise_weight * noise - sample_weight * sample


def __len__(self):
    """Return the number of training timesteps stored in the config."""
    return self.config.num_train_timesteps


def previous_timestep(self, timestep):
    """
    Return the timestep that follows `timestep` in the denoising order.

    With custom timesteps this is the next entry of `self.timesteps` (or `-1` after the last one); otherwise it is
    `timestep` minus the fixed stride `num_train_timesteps // num_inference_steps`.
    """
    if not self.custom_timesteps:
        steps = self.num_inference_steps or self.config.num_train_timesteps
        return timestep - self.config.num_train_timesteps // steps

    position = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
    if position == self.timesteps.shape[0] - 1:
        return torch.tensor(-1)
    return self.timesteps[position + 1]
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Build a beta schedule by discretizing a continuous `alpha_bar(t)` curve, i.e. the cumulative product of
    `(1 - beta)` over normalized time `t` in `[0, 1]`.

    For step `i` the beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`):
            Number of betas to produce (`N` above).
        max_beta (`float`):
            Upper bound applied to every beta; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`):
            Shape of the `alpha_bar` curve. Choose from `cosine` or `exp`.

    Returns:
        `torch.Tensor`: float32 tensor holding `num_diffusion_timesteps` betas.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    schedule = [
        # ratio of consecutive alpha_bar values gives the per-step alpha; beta is its complement
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta)
        for step in range(total)
    ]
    return torch.tensor(schedule, dtype=torch.float32)
def set_timesteps(
    self,
    num_inference_steps: int = None,
    timesteps: Optional[List[int]] = None,
    device: Union[str, torch.device] = None,
):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Args:
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps. When `timesteps` is `None`, `num_inference_steps + 1` evenly spaced values
            from `1.0` down to `0.0` are generated.
        timesteps (`List[int]` or `torch.Tensor`, *optional*):
            Explicit timesteps to use instead of the evenly spaced default. Takes precedence over
            `num_inference_steps` when both are given.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps are moved. If `None`, they are left where they are.

    Raises:
        ValueError: if neither `num_inference_steps` nor `timesteps` is given.
    """
    if timesteps is None:
        # Without this guard `None + 1` below fails with an opaque TypeError.
        if num_inference_steps is None:
            raise ValueError("Must pass one of `num_inference_steps` or `timesteps`.")
        timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device)

    if not isinstance(timesteps, torch.Tensor):
        timesteps = torch.Tensor(timesteps).to(device)
    elif device is not None:
        # Caller-supplied tensors are moved as well, as the `device` argument promises.
        timesteps = timesteps.to(device)

    self.timesteps = timesteps
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """
    Mix `original_samples` with `noise` at the noise level given by `timesteps`.

    The result is `sqrt(alpha_bar) * original_samples + sqrt(1 - alpha_bar) * noise`, where `alpha_bar` comes from
    `self._alpha_cumprod(timesteps)` with one value per batch entry.

    Args:
        original_samples (`torch.Tensor`): clean samples.
        noise (`torch.Tensor`): noise to blend in.
        timesteps (`torch.Tensor`): one timestep per batch entry.

    Returns:
        `torch.Tensor`: noisy samples, cast back to the dtype of `original_samples`.
    """
    target_dtype = original_samples.dtype

    # (batch, 1, ..., 1) so the per-entry level broadcasts over all remaining dims
    broadcast_shape = (timesteps.size(0),) + (1,) * len(original_samples.shape[1:])
    signal_level = self._alpha_cumprod(timesteps, device=original_samples.device).view(broadcast_shape)

    noised = signal_level.sqrt() * original_samples + (1 - signal_level).sqrt() * noise
    return noised.to(dtype=target_dtype)
def __len__(self):
    # Returns the `num_train_timesteps` entry of the scheduler config.
    # NOTE(review): the `__init__` visible in this file registers only `scaler` and `s`, so this key
    # looks like it is never set for this scheduler and the lookup may fail - confirm and fix upstream.
    return self.config.num_train_timesteps


def previous_timestep(self, timestep):
    # Returns, for every batch entry, the schedule value that comes right after the current timestep.
    # Only `timestep[0]` is inspected: the result is one value expanded to `timestep.shape[0]` entries.
    # NOTE(review): presumably all entries of `timestep` are expected to be equal - verify against callers.
    # Locate the schedule entry closest to the current timestep (nearest match, not exact equality).
    index = (self.timesteps - timestep[0]).abs().argmin().item()
    # NOTE(review): `index + 1` is out of range when the current timestep is the last schedule entry;
    # `set_timesteps` builds `num_inference_steps + 1` values by default, which looks like it leaves a
    # final entry that is only ever used as the "previous" one - confirm.
    prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0])
    return prev_t
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): """ `DEISMultistepScheduler` is a fast high order solver for diffusion ordinary differential equations (ODEs). This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. beta_start (`float`, defaults to 0.0001): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. beta_schedule (`str`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. 
trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. solver_order (`int`, defaults to 2): The DEIS order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, defaults to `epsilon`): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. dynamic_thresholding_ratio (`float`, defaults to 0.995): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True`. algorithm_type (`str`, defaults to `deis`): The algorithm type for the solver. lower_order_final (`bool`, defaults to `True`): Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "deis", solver_type: str = "logrho", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # Currently we only support VP-type noise schedule self.alpha_t = torch.sqrt(self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # settings for DEIS if algorithm_type not in ["deis"]: if algorithm_type in ["dpmsolver", "dpmsolver++"]: self.register_to_config(algorithm_type="deis") else: raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") if solver_type not in 
["logrho"]: if solver_type in ["midpoint", "heun", "bh1", "bh2"]: self.register_to_config(solver_type="logrho") else: raise NotImplementedError(f"solver type {solver_type} is not implemented for {self.__class__}") # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 self._step_index = None self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property def step_index(self): """ The index counter for current timestep. It will increase 1 after each scheduler step. """ return self._step_index @property def begin_index(self): """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index def set_begin_index(self, begin_index: int = 0): """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: begin_index (`int`): The begin index for the scheduler. """ self._begin_index = begin_index def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) self.num_inference_steps = len(timesteps) self.model_outputs = [ None, ] * self.config.solver_order self.lower_order_nums = 0 # add an index counter for schedulers that allow duplicated timesteps self._step_index = None self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing pixels from saturation at each step. We find that dynamic thresholding results in significantly better photorealism as well as better image-text alignment, especially when using very large guidance weights." 
https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) abs_sample = sample.abs() # "a certain percentile absolute pixel value" s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) s = torch.clamp( s, min=1, max=self.config.sample_max_value ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" sample = sample.reshape(batch_size, channels, *remaining_dims) sample = sample.to(dtype) return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) # get distribution dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low = log_sigmas[low_idx] high = log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) w = np.clip(w, 0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): alpha_t = 1 / ((sigma**2 + 1) ** 0.5) sigma_t = sigma * alpha_t return alpha_t, sigma_t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: 
torch.Tensor, num_inference_steps) -> torch.Tensor: """Constructs the noise schedule of Karras et al. (2022).""" # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers if hasattr(self.config, "sigma_min"): sigma_min = self.config.sigma_min else: sigma_min = None if hasattr(self.config, "sigma_max"): sigma_max = self.config.sigma_max else: sigma_max = None sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, num_inference_steps) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas def convert_model_output( self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, **kwargs, ) -> torch.Tensor: """ Convert the model output to the corresponding type the DEIS algorithm needs. Args: model_output (`torch.Tensor`): The direct output from the learned diffusion model. timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. Returns: `torch.Tensor`: The converted model output. 
""" timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) if sample is None: if len(args) > 1: sample = args[1] else: raise ValueError("missing `sample` as a required keyward argument") if timestep is not None: deprecate( "timesteps", "1.0.0", "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) if self.config.prediction_type == "epsilon": x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction` for the DEISMultistepScheduler." ) if self.config.thresholding: x0_pred = self._threshold_sample(x0_pred) if self.config.algorithm_type == "deis": return (sample - alpha_t * x0_pred) / sigma_t else: raise NotImplementedError("only support log-rho multistep deis now") def deis_first_order_update( self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, **kwargs, ) -> torch.Tensor: """ One step for the first-order DEIS (equivalent to DDIM). Args: model_output (`torch.Tensor`): The direct output from the learned diffusion model. timestep (`int`): The current discrete timestep in the diffusion chain. prev_timestep (`int`): The previous discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. Returns: `torch.Tensor`: The sample tensor at the previous timestep. 
""" timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) if sample is None: if len(args) > 2: sample = args[2] else: raise ValueError(" missing `sample` as a required keyward argument") if timestep is not None: deprecate( "timesteps", "1.0.0", "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) if prev_timestep is not None: deprecate( "prev_timestep", "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s = torch.log(alpha_s) - torch.log(sigma_s) h = lambda_t - lambda_s if self.config.algorithm_type == "deis": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output else: raise NotImplementedError("only support log-rho multistep deis now") return x_t def multistep_deis_second_order_update( self, model_output_list: List[torch.Tensor], *args, sample: torch.Tensor = None, **kwargs, ) -> torch.Tensor: """ One step for the second-order multistep DEIS. Args: model_output_list (`List[torch.Tensor]`): The direct outputs from learned diffusion model at current and latter timesteps. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. Returns: `torch.Tensor`: The sample tensor at the previous timestep. 
""" timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) if sample is None: if len(args) > 2: sample = args[2] else: raise ValueError(" missing `sample` as a required keyward argument") if timestep_list is not None: deprecate( "timestep_list", "1.0.0", "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) if prev_timestep is not None: deprecate( "prev_timestep", "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) sigma_t, sigma_s0, sigma_s1 = ( self.sigmas[self.step_index + 1], self.sigmas[self.step_index], self.sigmas[self.step_index - 1], ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) m0, m1 = model_output_list[-1], model_output_list[-2] rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1 if self.config.algorithm_type == "deis": def ind_fn(t, b, c): # Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}] return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c)) coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1) coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0) x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1) return x_t else: raise NotImplementedError("only support log-rho multistep deis now") def multistep_deis_third_order_update( self, model_output_list: List[torch.Tensor], *args, sample: torch.Tensor = None, **kwargs, ) -> torch.Tensor: """ One step for the third-order multistep DEIS. Args: model_output_list (`List[torch.Tensor]`): The direct outputs from learned diffusion model at current and latter timesteps. 
sample (`torch.Tensor`): A current instance of a sample created by diffusion process. Returns: `torch.Tensor`: The sample tensor at the previous timestep. """ timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) if sample is None: if len(args) > 2: sample = args[2] else: raise ValueError(" missing`sample` as a required keyward argument") if timestep_list is not None: deprecate( "timestep_list", "1.0.0", "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) if prev_timestep is not None: deprecate( "prev_timestep", "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( self.sigmas[self.step_index + 1], self.sigmas[self.step_index], self.sigmas[self.step_index - 1], self.sigmas[self.step_index - 2], ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] rho_t, rho_s0, rho_s1, rho_s2 = ( sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1, sigma_s2 / alpha_s2, ) if self.config.algorithm_type == "deis": def ind_fn(t, b, c, d): # Integrate[(log(t) - log(c))(log(t) - log(d)) / (log(b) - log(c))(log(b) - log(d)), {t}] numerator = t * ( np.log(c) * (np.log(d) - np.log(t) + 1) - np.log(d) * np.log(t) + np.log(d) + np.log(t) ** 2 - 2 * np.log(t) + 2 ) denominator = (np.log(b) - np.log(c)) * (np.log(b) - np.log(d)) return numerator / denominator coef1 = ind_fn(rho_t, rho_s0, rho_s1, rho_s2) - ind_fn(rho_s0, rho_s0, rho_s1, rho_s2) coef2 = ind_fn(rho_t, rho_s1, rho_s2, rho_s0) 
- ind_fn(rho_s0, rho_s1, rho_s2, rho_s0) coef3 = ind_fn(rho_t, rho_s2, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s2, rho_s0, rho_s1) x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1 + coef3 * m2) return x_t else: raise NotImplementedError("only support log-rho multistep deis now") # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps index_candidates = (schedule_timesteps == timestep).nonzero() if len(index_candidates) == 0: step_index = len(self.timesteps) - 1 # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) elif len(index_candidates) > 1: step_index = index_candidates[1].item() else: step_index = index_candidates[0].item() return step_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index def _init_step_index(self, timestep): """ Initialize the step_index counter for the scheduler. """ if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) self._step_index = self.index_for_timestep(timestep) else: self._step_index = self._begin_index def step( self, model_output: torch.Tensor, timestep: Union[int, torch.Tensor], sample: torch.Tensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with the multistep DEIS. Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. timestep (`int`): The current discrete timestep in the diffusion chain. 
sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if self.step_index is None: self._init_step_index(timestep) lower_order_final = ( (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 ) lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) model_output = self.convert_model_output(model_output, sample=sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: prev_sample = self.deis_first_order_update(model_output, sample=sample) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample) else: prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 # upon completion increase step index by one self._step_index += 1 if not return_dict: return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. 
Args: sample (`torch.Tensor`): The input sample. Returns: `torch.Tensor`: A scaled input sample. """ return sample # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise def add_noise( self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] elif self.step_index is not None: # add_noise is called after first denoising step (for inpainting) step_indices = [self.step_index] * timesteps.shape[0] else: # add noise is called before first denoising step to create initial latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: src/diffusers/schedulers/scheduling_dpm_cogvideox.py ================================================ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. 
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


def rescale_zero_terminal_snr(alphas_cumprod):
    """
    Rescales the cumulative alpha products so that the last one is zero (zero terminal SNR), following the shift and
    scale of https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Unlike the paper's algorithm, this variant works on `alphas_cumprod` directly and does not convert the result back
    to betas. The input tensor is not modified; the in-place operations act on the result of `sqrt()`.


    Args:
        alphas_cumprod (`torch.Tensor`):
            the cumulative alpha products that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled cumulative alpha products whose last entry is zero
    """

    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Square to recover alphas_bar; no conversion to betas happens here
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt

    return alphas_bar


class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
    """
    `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
    non-Markovian guidance.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. dynamic_thresholding_ratio (`float`, defaults to 0.995): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.0120,
    beta_schedule: str = "scaled_linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    clip_sample: bool = True,
    set_alpha_to_one: bool = True,
    steps_offset: int = 0,
    prediction_type: str = "epsilon",
    clip_sample_range: float = 1.0,
    sample_max_value: float = 1.0,
    timestep_spacing: str = "leading",
    rescale_betas_zero_snr: bool = False,
    snr_shift_scale: float = 3.0,
):
    """
    Build the beta / alpha schedule and initialise the mutable sampling state.

    `trained_betas`, when given, takes precedence over `beta_schedule`. The cumulative alphas are always passed
    through the SNR shift controlled by `snr_shift_scale`, and additionally through `rescale_zero_terminal_snr`
    when `rescale_betas_zero_snr` is set.

    Raises:
        NotImplementedError: if `beta_schedule` is not one of `linear`, `scaled_linear` or `squaredcos_cap_v2`
        (and `trained_betas` is `None`).
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model (note: built in float64 here).
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64)
        betas = sqrt_betas**2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.betas = betas
    self.alphas = 1.0 - self.betas

    cumulative = torch.cumprod(self.alphas, dim=0)
    # SNR shift (the original code attributes this modification to SD3).
    cumulative = cumulative / (snr_shift_scale + (1 - snr_shift_scale) * cumulative)
    # Optional rescale for zero terminal SNR, applied after the shift.
    if rescale_betas_zero_snr:
        cumulative = rescale_zero_terminal_snr(cumulative)
    self.alphas_cumprod = cumulative

    # Each step looks up the alphas_cumprod of the previous timestep. The final step has no previous entry, so
    # `set_alpha_to_one` selects between a fixed 1.0 and the first entry of the schedule.
    if set_alpha_to_one:
        self.final_alpha_cumprod = torch.tensor(1.0)
    else:
        self.final_alpha_cumprod = self.alphas_cumprod[0]

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # setable values
    self.num_inference_steps = None
    descending_steps = np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)
    self.timesteps = torch.from_numpy(descending_steps)
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) .round()[::-1] .copy() .astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." 
def step(
    self,
    model_output: torch.Tensor,
    old_pred_original_sample: torch.Tensor,
    timestep: int,
    timestep_back: int,
    sample: torch.Tensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = False,
) -> Union[DDIMSchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    A first-order update is always computed. When `old_pred_original_sample` is given and the previous timestep
    is non-negative, it is replaced by a second-order update that combines the current and the previous
    predicted original sample.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        old_pred_original_sample (`torch.Tensor` or `None`):
            The predicted original sample from the preceding call. Pass `None` on the first call; only the
            first-order update is then used.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        timestep_back (`int` or `None`):
            The timestep whose `alphas_cumprod` entry is used as the "back" value for the second-order
            coefficients. It must not be `None` when the second-order update is taken.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        eta (`float`):
            Not read by this implementation; kept in the signature only.
        use_clipped_model_output (`bool`, defaults to `False`):
            Not read by this implementation; kept in the signature only.
        generator (`torch.Generator`, *optional*):
            A random number generator used for the noise drawn in this step.
        variance_noise (`torch.Tensor`, *optional*):
            Not read by this implementation; the noise is always drawn with `generator`.
        return_dict (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise
            a tuple `(prev_sample, pred_original_sample)` is returned.

    Raises:
        ValueError: if `set_timesteps` has not been called, or if `prediction_type` is not supported.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # 1. get previous step value (=t-1)
    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[timestep]
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
    alpha_prod_t_back = self.alphas_cumprod[timestep_back] if timestep_back is not None else None
    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample ("predicted x_0") from the model output
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. solver coefficients; the two lambda values returned by `get_variables` are not needed here
    h, r, _, _ = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)
    mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back))
    mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5

    # 5. first-order update. The noise is drawn unconditionally so the generator is consumed exactly as before,
    # even when the second-order branch below overwrites the result.
    noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
    prev_sample = mult[0] * sample - mult[1] * pred_original_sample + mult_noise * noise

    # 6. second-order update, skipped on the first call (no history) and on the final step (prev_timestep < 0)
    if old_pred_original_sample is not None and prev_timestep >= 0:
        denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample
        noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
        prev_sample = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise

    # Both update paths now honour `return_dict`; previously the first-order path always returned a tuple.
    if not return_dict:
        return (prev_sample, pred_original_sample)

    return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Mix `original_samples` with `noise` using the schedule entries selected by `timesteps`.

    Computes `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise`, where the per-timestep
    factors are flattened and given trailing singleton dimensions until they have as many dimensions as
    `original_samples`.

    Side effect: `self.alphas_cumprod` is moved to the device of `original_samples` and stored back, so later
    calls do not repeat the transfer. Its dtype is left unchanged; the cast to the sample dtype is local.

    Args:
        original_samples (`torch.Tensor`): the samples to be noised.
        noise (`torch.Tensor`): the noise to mix in.
        timesteps (`torch.IntTensor`): indices into `self.alphas_cumprod`.

    Returns:
        `torch.Tensor`: the noisy samples.
    """
    target_device = original_samples.device
    self.alphas_cumprod = self.alphas_cumprod.to(device=target_device)
    schedule = self.alphas_cumprod.to(dtype=original_samples.dtype)
    timesteps = timesteps.to(target_device)

    selected = schedule[timesteps]
    signal_scale = (selected**0.5).flatten()
    noise_scale = ((1 - selected) ** 0.5).flatten()

    # Both factors are 1-D after `flatten`; append singleton axes so they broadcast against the samples.
    missing_dims = len(original_samples.shape) - len(signal_scale.shape)
    for _ in range(missing_dims):
        signal_scale = signal_scale.unsqueeze(-1)
        noise_scale = noise_scale.unsqueeze(-1)

    return signal_scale * original_samples + noise_scale * noise
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous `alpha_bar(t)` function, defined for t in [0, 1] as the cumulative product of
    (1 - beta), into a beta schedule.

    Beta `i` is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`, where `N` is
    `num_diffusion_timesteps`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): a float32 tensor with `num_diffusion_timesteps` betas.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type not in ("cosine", "exp"):
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    use_cosine = alpha_transform_type == "cosine"

    def alpha_bar(t):
        if use_cosine:
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
        return math.exp(t * -12.0)

    total = num_diffusion_timesteps
    schedule = [
        min(1 - alpha_bar((index + 1) / total) / alpha_bar(index / total), max_beta) for index in range(total)
    ]
    return torch.tensor(schedule, dtype=torch.float32)
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) # Convert alphas_bar_sqrt to betas alphas_bar = alphas_bar_sqrt**2 # Revert sqrt alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod alphas = torch.cat([alphas_bar[0:1], alphas]) betas = 1 - alphas return betas class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. beta_start (`float`, defaults to 0.0001): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. beta_schedule (`str`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. solver_order (`int`, defaults to 2): The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. 
dynamic_thresholding_ratio (`float`, defaults to 0.995): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. solver_type (`str`, defaults to `midpoint`): Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. lower_order_final (`bool`, defaults to `True`): Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. euler_at_final (`bool`, defaults to `False`): Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference steps, but sometimes may result in blurring. use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. 
use_lu_lambdas (`bool`, *optional*, defaults to `False`): Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of `lambda(t)`. final_sigmas_type (`str`, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. variance_type (`str`, *optional*): Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output contains the predicted Gaussian variance. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, euler_at_final: bool = False, use_karras_sigmas: Optional[bool] = False, use_lu_lambdas: Optional[bool] = False, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, rescale_betas_zero_snr: bool = False, ): if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) if rescale_betas_zero_snr: # Close to 0 without being 0 so first sigma is not inf # FP16 smallest positive subnormal works well here self.alphas_cumprod[-1] = 2**-24 # Currently we only support VP-type noise schedule self.alpha_t = torch.sqrt(self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # settings for DPM-Solver if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") if solver_type not in ["midpoint", "heun"]: if solver_type in ["logrho", "bh1", "bh2"]: self.register_to_config(solver_type="midpoint") else: raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": raise ValueError( f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." 
def set_timesteps(
    self,
    num_inference_steps: int = None,
    device: Union[str, torch.device] = None,
    timesteps: Optional[List[int]] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Besides `self.timesteps`, this also rebuilds `self.sigmas` (one entry longer than the timesteps, because a
    final sigma is appended) and resets the multistep solver state (`model_outputs`, `lower_order_nums`,
    `_step_index`, `_begin_index`).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
            based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` must be
            `None`, and `timestep_spacing` attribute will be ignored.

    Raises:
        ValueError: if neither or both of `num_inference_steps` and `timesteps` are given, if `timesteps` is
        combined with `config.use_karras_sigmas` or `config.use_lu_lambdas`, or if `config.timestep_spacing` /
        `config.final_sigmas_type` holds an unsupported value.
    """
    # Argument validation: exactly one of `num_inference_steps` / `timesteps`, and custom timesteps cannot be
    # combined with the sigma-driven schedules below (which would overwrite them).
    if num_inference_steps is None and timesteps is None:
        raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
    if num_inference_steps is not None and timesteps is not None:
        raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
    if timesteps is not None and self.config.use_karras_sigmas:
        raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
    if timesteps is not None and self.config.use_lu_lambdas:
        raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")

    if timesteps is not None:
        timesteps = np.array(timesteps).astype(np.int64)
    else:
        # Clipping the minimum of all lambda(t) for numerical stability.
        # This is critical for cosine (squaredcos_cap_v2) noise schedule.
        # `lambda_t` is flipped before the search; presumably it decreases with t and `searchsorted` needs an
        # ascending sequence -- TODO confirm.
        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
        last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            # num_inference_steps + 1 grid points, reversed to descending order, then the last (smallest) one
            # is dropped so that exactly num_inference_steps timesteps remain.
            timesteps = (
                np.linspace(0, last_timestep - 1, num_inference_steps + 1)
                .round()[::-1][:-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = last_timestep // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (
                (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            )
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            # NOTE: the ratio uses config.num_train_timesteps while the range starts at last_timestep.
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

    # Sigmas of the full training schedule, in training-timestep order.
    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    log_sigmas = np.log(sigmas)

    if self.config.use_karras_sigmas:
        # Karras schedule: sigmas are generated first and the timesteps are derived from them, replacing the
        # spacing-based timesteps computed above.
        sigmas = np.flip(sigmas).copy()
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
    elif self.config.use_lu_lambdas:
        # Lu schedule: same idea, but the interpolation happens on the flipped log-sigmas.
        lambdas = np.flip(log_sigmas.copy())
        lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
        sigmas = np.exp(lambdas)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
    else:
        # Default: look up (linearly interpolate) the training sigmas at the chosen timesteps.
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

    if self.config.final_sigmas_type == "sigma_min":
        sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
    elif self.config.final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
        )

    # Append the final sigma, so `sigmas` has one more entry than `timesteps`.
    sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

    self.sigmas = torch.from_numpy(sigmas)
    # The karras / lu branches leave `timesteps` as rounded floats; the int64 cast happens here.
    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

    self.num_inference_steps = len(timesteps)

    # Reset the history of model outputs kept for the multistep update.
    self.model_outputs = [
        None,
    ] * self.config.solver_order
    self.lower_order_nums = 0

    # add an index counter for schedulers that allow duplicated timesteps
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing pixels from saturation at each step. We find that dynamic thresholding results in significantly better photorealism as well as better image-text alignment, especially when using very large guidance weights." https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) abs_sample = sample.abs() # "a certain percentile absolute pixel value" s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) s = torch.clamp( s, min=1, max=self.config.sample_max_value ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" sample = sample.reshape(batch_size, channels, *remaining_dims) sample = sample.to(dtype) return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) # get distribution dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low = log_sigmas[low_idx] high = log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) w = np.clip(w, 0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t def _sigma_to_alpha_sigma_t(self, sigma): alpha_t = 1 / ((sigma**2 + 1) ** 0.5) 
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
    """
    Constructs the noise schedule of Karras et al. (2022).

    Args:
        in_sigmas: Sigmas ordered from largest to smallest; only the first and last entries are read, and only
            when the config does not provide `sigma_max` / `sigma_min`.
        num_inference_steps: Number of sigmas to generate.

    Returns:
        NumPy array of `num_inference_steps` sigmas running from the maximum down to the minimum.
    """
    # Hack to make sure that other schedulers which copy this function don't break
    # TODO: Add this logic to the other schedulers
    configured_min = getattr(self.config, "sigma_min", None)
    configured_max = getattr(self.config, "sigma_max", None)

    # Fall back to the ends of the supplied schedule when the config leaves a bound unset.
    lowest = in_sigmas[-1].item() if configured_min is None else configured_min
    highest = in_sigmas[0].item() if configured_max is None else configured_max

    rho = 7.0  # 7.0 is the value used in the paper
    lowest_root = lowest ** (1 / rho)
    highest_root = highest ** (1 / rho)
    fractions = np.linspace(0, 1, num_inference_steps)
    return (highest_root + fractions * (lowest_root - highest_root)) ** rho


def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
    """
    Constructs the noise schedule of Lu et al. (2022).

    Args:
        in_lambdas: Lambda values; the first entry is used as the start and the last entry as the end.
        num_inference_steps: Number of lambdas to generate.

    Returns:
        NumPy array of `num_inference_steps` lambdas interpolated between the two end values.
    """
    rho = 1.0  # 1.0 is the value used in the paper
    end_value: float = in_lambdas[-1].item()
    start_value: float = in_lambdas[0].item()

    end_root = end_value ** (1 / rho)
    start_root = start_value ** (1 / rho)
    fractions = np.linspace(0, 1, num_inference_steps)
    return (start_root + fractions * (end_root - start_root)) ** rho
You can use either DPMSolver or DPMSolver++ for both noise prediction and data prediction models. Args: model_output (`torch.Tensor`): The direct output from the learned diffusion model. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. Returns: `torch.Tensor`: The converted model output. """ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) if sample is None: if len(args) > 1: sample = args[1] else: raise ValueError("missing `sample` as a required keyward argument") if timestep is not None: deprecate( "timesteps", "1.0.0", "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction` for the DPMSolverMultistepScheduler." ) if self.config.thresholding: x0_pred = self._threshold_sample(x0_pred) return x0_pred # DPM-Solver needs to solve an integral of the noise prediction model. elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. 
def dpm_solver_first_order_update(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the first-order DPMSolver (equivalent to DDIM).

    Args:
        model_output (`torch.Tensor`):
            The converted model output for the current step.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. May also be given as the third
            positional argument after the deprecated `timestep` and `prev_timestep`.
        noise (`torch.Tensor`, *optional*):
            Noise used by the stochastic variants; required when `algorithm_type` is `"sde-dpmsolver"` or
            `"sde-dpmsolver++"`.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.

    Raises:
        ValueError: If `sample` is missing, if `noise` is missing for an SDE algorithm type, or if
            `algorithm_type` is not one of the supported values.
    """
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    algorithm_type = self.config.algorithm_type
    # Validate explicitly: an `assert` disappears under `python -O`, and an unknown algorithm type used to fall
    # through to `return x_t` with `x_t` never assigned.
    if algorithm_type not in ("dpmsolver++", "dpmsolver", "sde-dpmsolver++", "sde-dpmsolver"):
        raise ValueError(
            f"algorithm_type given as {algorithm_type} must be one of `dpmsolver`, `dpmsolver++`,"
            " `sde-dpmsolver` or `sde-dpmsolver++` for the first-order update."
        )
    if algorithm_type in ("sde-dpmsolver++", "sde-dpmsolver") and noise is None:
        raise ValueError(f"`noise` is required when algorithm_type is {algorithm_type}")

    # `t` is the step being moved to, `s` the current step.
    sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

    h = lambda_t - lambda_s
    if algorithm_type == "dpmsolver++":
        x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
    elif algorithm_type == "dpmsolver":
        x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
    elif algorithm_type == "sde-dpmsolver++":
        x_t = (
            (sigma_t / sigma_s * torch.exp(-h)) * sample
            + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )
    else:  # "sde-dpmsolver"
        x_t = (
            (alpha_t / alpha_s) * sample
            - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
            + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
        )
    return x_t
multistep DPMSolver. Args: model_output_list (`List[torch.Tensor]`): The direct outputs from learned diffusion model at current and latter timesteps. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. Returns: `torch.Tensor`: The sample tensor at the previous timestep. """ timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) if sample is None: if len(args) > 2: sample = args[2] else: raise ValueError(" missing `sample` as a required keyward argument") if timestep_list is not None: deprecate( "timestep_list", "1.0.0", "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) if prev_timestep is not None: deprecate( "prev_timestep", "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) sigma_t, sigma_s0, sigma_s1 = ( self.sigmas[self.step_index + 1], self.sigmas[self.step_index], self.sigmas[self.step_index - 1], ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) m0, m1 = model_output_list[-1], model_output_list[-2] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / 
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the third-order multistep DPMSolver.

    Args:
        model_output_list (`List[torch.Tensor]`):
            The converted model outputs; the last three entries are used, with the most recent one last.
        sample (`torch.Tensor`):
            A current instance of a sample created by diffusion process. May also be given as the third
            positional argument after the deprecated `timestep_list` and `prev_timestep`.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.

    Raises:
        ValueError: If `sample` is missing, or if `algorithm_type` is not `dpmsolver` or `dpmsolver++` (this
            update has no stochastic variant).
    """
    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    algorithm_type = self.config.algorithm_type
    # Only the deterministic solvers are implemented here; previously any other value reached `return x_t`
    # with `x_t` unassigned.
    if algorithm_type not in ("dpmsolver++", "dpmsolver"):
        raise ValueError(
            f"algorithm_type given as {algorithm_type} is not supported by the third-order update; it must be"
            " one of `dpmsolver` or `dpmsolver++`."
        )

    # `t` is the step being moved to; `s0`, `s1`, `s2` are the current and the two preceding steps.
    sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
        self.sigmas[self.step_index - 2],
    )

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
    alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
    lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

    m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

    h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
    r0, r1 = h_0 / h, h_1 / h
    D0 = m0
    # Finite-difference terms built from the three stored model outputs.
    D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
    if algorithm_type == "dpmsolver++":
        # See https://arxiv.org/abs/2211.01095 for detailed derivations
        x_t = (
            (sigma_t / sigma_s0) * sample
            - (alpha_t * (torch.exp(-h) - 1.0)) * D0
            + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
        )
    else:  # "dpmsolver"
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        x_t = (
            (alpha_t / alpha_s0) * sample
            - (sigma_t * (torch.exp(h) - 1.0)) * D0
            - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
            - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
        )
    return x_t


def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Return the position of `timestep` within a timestep schedule.

    Args:
        timestep: The timestep value to look up.
        schedule_timesteps: Schedule to search; defaults to `self.timesteps`.

    Returns:
        `int`: The matching index. With no match, the last index of `self.timesteps` is returned; with
        several matches, the second one is returned.
    """
    if schedule_timesteps is None:
        schedule_timesteps = self.timesteps

    index_candidates = (schedule_timesteps == timestep).nonzero()

    if len(index_candidates) == 0:
        step_index = len(self.timesteps) - 1
    # The sigma index that is taken for the **very** first `step`
    # is always the second index (or the last index if there is only 1)
    # This way we can ensure we don't accidentally skip a sigma in
    # case we start in the middle of the denoising schedule (e.g. for image-to-image)
    elif len(index_candidates) > 1:
        step_index = index_candidates[1].item()
    else:
        step_index = index_candidates[0].item()

    return step_index


def _init_step_index(self, timestep):
    """
    Initialize the step_index counter for the scheduler.

    When `begin_index` is unset the index is looked up from `timestep` (tensor timesteps are first moved to
    the device of `self.timesteps`); otherwise the stored `_begin_index` is used directly.
    """

    if self.begin_index is None:
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        self._step_index = self.index_for_timestep(timestep)
    else:
        self._step_index = self._begin_index
Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. generator (`torch.Generator`, *optional*): A random number generator. variance_noise (`torch.Tensor`): Alternative to generating noise with `generator` by directly providing the noise for the variance itself. Useful for methods such as [`LEdits++`]. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if self.step_index is None: self._init_step_index(timestep) # Improve numerical stability for small number of steps lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) or self.config.final_sigmas_type == "zero" ) lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) model_output = self.convert_model_output(model_output, sample=sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output # Upcast to avoid precision issues when computing prev_sample sample = sample.to(torch.float32) if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32 ) elif 
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep.

    Args:
        sample (`torch.Tensor`):
            The input sample.

    Returns:
        `torch.Tensor`:
            The input sample, returned unchanged.
    """
    return sample


def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Mix `noise` into `original_samples` at the noise level selected for each entry of `timesteps`.

    Returns:
        `torch.Tensor`: `alpha_t * original_samples + sigma_t * noise`, with per-sample coefficients
        broadcast over the trailing dimensions.
    """
    target_device = original_samples.device

    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    sigmas = self.sigmas.to(device=target_device, dtype=original_samples.dtype)
    if target_device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(target_device, dtype=torch.float32)
        timesteps = timesteps.to(target_device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(target_device)
        timesteps = timesteps.to(target_device)

    if self.begin_index is None:
        # begin_index is None when the scheduler is used for training or pipeline does not implement
        # set_begin_index
        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    else:
        if self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            shared_index = self.step_index
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            shared_index = self.begin_index
        step_indices = [shared_index] * timesteps.shape[0]

    sigma = sigmas[step_indices].flatten()
    # Append singleton axes so the per-sample sigma broadcasts against the sample tensor.
    missing_axes = len(original_samples.shape) - len(sigma.shape)
    for _ in range(max(missing_axes, 0)):
        sigma = sigma.unsqueeze(-1)

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
    return alpha_t * original_samples + sigma_t * noise


def __len__(self):
    """Return the configured number of training timesteps."""
    return self.config.num_train_timesteps
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver from dataclasses import dataclass from typing import List, Optional, Tuple, Union import flax import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, ) @flax.struct.dataclass class DPMSolverMultistepSchedulerState: common: CommonSchedulerState alpha_t: jnp.ndarray sigma_t: jnp.ndarray lambda_t: jnp.ndarray # setable values init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray num_inference_steps: Optional[int] = None # running values model_outputs: Optional[jnp.ndarray] = None lower_order_nums: Optional[jnp.int32] = None prev_timestep: Optional[jnp.int32] = None cur_sample: Optional[jnp.ndarray] = None @classmethod def create( cls, common: CommonSchedulerState, alpha_t: jnp.ndarray, sigma_t: jnp.ndarray, lambda_t: jnp.ndarray, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, ): return cls( common=common, alpha_t=alpha_t, sigma_t=sigma_t, lambda_t=lambda_t, init_noise_sigma=init_noise_sigma, timesteps=timesteps, ) @dataclass class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput): state: DPMSolverMultistepSchedulerState class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): """ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with the convergence order guarantee. 
Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality samples, and it can generate quite good samples even in only 10 steps. For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. 
We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, default `epsilon`): indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, or `v-prediction`. thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++`. algorithm_type (`str`, default `dpmsolver++`): the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). solver_type (`str`, default `midpoint`): the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are slightly better, so we recommend to use the `midpoint` type. lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. 
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
    """
    Build the initial scheduler state.

    Args:
        common: Shared scheduler state; created from this scheduler when omitted.

    Returns:
        A `DPMSolverMultistepSchedulerState` holding the alpha/sigma/lambda tables, the initial noise sigma
        and the full descending training timestep range.

    Raises:
        NotImplementedError: If the configured `algorithm_type` or `solver_type` is not supported.
    """
    if common is None:
        common = CommonSchedulerState.create(self)

    # Currently we only support VP-type noise schedule
    alpha_table = jnp.sqrt(common.alphas_cumprod)
    sigma_table = jnp.sqrt(1 - common.alphas_cumprod)
    lambda_table = jnp.log(alpha_table) - jnp.log(sigma_table)

    # settings for DPM-Solver
    supported_algorithms = ["dpmsolver", "dpmsolver++"]
    if self.config.algorithm_type not in supported_algorithms:
        raise NotImplementedError(f"{self.config.algorithm_type} is not implemented for {self.__class__}")
    supported_solvers = ["midpoint", "heun"]
    if self.config.solver_type not in supported_solvers:
        raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}")

    # standard deviation of the initial noise distribution
    initial_sigma = jnp.array(1.0, dtype=self.dtype)

    # Every training timestep, highest first.
    ascending = jnp.arange(0, self.config.num_train_timesteps).round()
    descending = ascending[::-1]

    return DPMSolverMultistepSchedulerState.create(
        common=common,
        alpha_t=alpha_table,
        sigma_t=sigma_table,
        lambda_t=lambda_table,
        init_noise_sigma=initial_sigma,
        timesteps=descending,
    )
) # initial running values model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype) lower_order_nums = jnp.int32(0) prev_timestep = jnp.int32(-1) cur_sample = jnp.zeros(shape, dtype=self.dtype) return state.replace( num_inference_steps=num_inference_steps, timesteps=timesteps, model_outputs=model_outputs, lower_order_nums=lower_order_nums, prev_timestep=prev_timestep, cur_sample=cur_sample, ) def convert_model_output( self, state: DPMSolverMultistepSchedulerState, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, ) -> jnp.ndarray: """ Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an integral of the data prediction model. So we need to first convert the model output to the corresponding type to match the algorithm. Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or DPM-Solver++ for both noise prediction model and data prediction model. Args: model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. Returns: `jnp.ndarray`: the converted model output. """ # DPM-Solver++ needs to solve an integral of the data prediction model. 
def convert_model_output(
    self,
    state: DPMSolverMultistepSchedulerState,
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
) -> jnp.ndarray:
    """
    Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.

    DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
    discretize an integral of the data prediction model. So we need to first convert the model output to the
    corresponding type to match the algorithm.

    Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
    DPM-Solver++ for both noise prediction model and data prediction model.

    Args:
        state (`DPMSolverMultistepSchedulerState`):
            scheduler state; its `alpha_t` / `sigma_t` tables are indexed with `timestep`.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.

    Returns:
        `jnp.ndarray`: the converted model output (an `x0` prediction for `dpmsolver++`, an epsilon prediction for
        `dpmsolver`).

    Raises:
        ValueError: if `config.prediction_type` is not `epsilon`, `sample` or `v_prediction`.
    """
    # DPM-Solver++ needs to solve an integral of the data prediction model.
    if self.config.algorithm_type == "dpmsolver++":
        if self.config.prediction_type == "epsilon":
            alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
            x0_pred = alpha_t * sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            # Dynamic thresholding in https://arxiv.org/abs/2205.11487
            # `dynamic_thresholding_ratio` is a fraction in [0, 1], so it must go through `quantile`
            # (`percentile` expects 0-100). `keepdims=True` keeps one threshold per batch entry with
            # singleton trailing axes so it broadcasts against `x0_pred`.
            reduce_axes = tuple(range(1, x0_pred.ndim))
            dynamic_max_val = jnp.quantile(
                jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=reduce_axes, keepdims=True
            )
            dynamic_max_val = jnp.maximum(
                dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val)
            )
            x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
        return x0_pred
    # DPM-Solver needs to solve an integral of the noise prediction model.
    elif self.config.algorithm_type == "dpmsolver":
        if self.config.prediction_type == "epsilon":
            return model_output
        elif self.config.prediction_type == "sample":
            alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
            epsilon = (sample - alpha_t * model_output) / sigma_t
            return epsilon
        elif self.config.prediction_type == "v_prediction":
            alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
            epsilon = alpha_t * model_output + sigma_t * sample
            return epsilon
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
            )
def dpm_solver_first_order_update(
    self,
    state: DPMSolverMultistepSchedulerState,
    model_output: jnp.ndarray,
    timestep: int,
    prev_timestep: int,
    sample: jnp.ndarray,
) -> jnp.ndarray:
    """
    Advance the sample by one first-order DPM-Solver step (equivalent to DDIM).

    See https://arxiv.org/abs/2206.00927 for the detailed derivation.

    Args:
        state (`DPMSolverMultistepSchedulerState`):
            scheduler state providing the `lambda_t`, `alpha_t` and `sigma_t` lookup tables.
        model_output (`jnp.ndarray`): converted model output used by the update.
        timestep (`int`): current discrete timestep in the diffusion chain.
        prev_timestep (`int`): previous discrete timestep in the diffusion chain (the step being moved to).
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.

    Returns:
        `jnp.ndarray`: the sample tensor at the previous timestep.
    """
    target, source = prev_timestep, timestep

    # step size in log-SNR space between the source and target timesteps
    h = state.lambda_t[target] - state.lambda_t[source]
    alpha_t, alpha_s = state.alpha_t[target], state.alpha_t[source]
    sigma_t, sigma_s = state.sigma_t[target], state.sigma_t[source]

    algorithm = self.config.algorithm_type
    if algorithm == "dpmsolver++":
        carry = sigma_t / sigma_s
        weight = alpha_t * (jnp.exp(-h) - 1.0)
        x_t = carry * sample - weight * model_output
    elif algorithm == "dpmsolver":
        carry = alpha_t / alpha_s
        weight = sigma_t * (jnp.exp(h) - 1.0)
        x_t = carry * sample - weight * model_output
    return x_t
def multistep_dpm_solver_second_order_update(
    self,
    state: DPMSolverMultistepSchedulerState,
    model_output_list: jnp.ndarray,
    timestep_list: List[int],
    prev_timestep: int,
    sample: jnp.ndarray,
) -> jnp.ndarray:
    """
    Advance the sample by one second-order multistep DPM-Solver step.

    Args:
        state (`DPMSolverMultistepSchedulerState`):
            scheduler state providing the `lambda_t`, `alpha_t` and `sigma_t` lookup tables.
        model_output_list (`List[jnp.ndarray]`):
            converted model outputs; the last two entries (`[-1]` current, `[-2]` the one before) are used.
        timestep_list (`List[int]`):
            discrete timesteps matching the last two entries of `model_output_list`.
        prev_timestep (`int`): previous discrete timestep in the diffusion chain (the step being moved to).
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.

    Returns:
        `jnp.ndarray`: the sample tensor at the previous timestep.
    """
    target = prev_timestep
    cur, older = timestep_list[-1], timestep_list[-2]
    out_cur, out_older = model_output_list[-1], model_output_list[-2]

    lam_target, lam_cur, lam_older = state.lambda_t[target], state.lambda_t[cur], state.lambda_t[older]
    alpha_t, alpha_s0 = state.alpha_t[target], state.alpha_t[cur]
    sigma_t, sigma_s0 = state.sigma_t[target], state.sigma_t[cur]

    # h: the step being taken in log-SNR space; h_0: the step taken previously
    h = lam_target - lam_cur
    h_0 = lam_cur - lam_older
    r0 = h_0 / h

    # zeroth- and first-order difference terms of the model output
    D0 = out_cur
    D1 = (1.0 / r0) * (out_cur - out_older)

    algorithm = self.config.algorithm_type
    solver = self.config.solver_type
    if algorithm == "dpmsolver++":
        # See https://arxiv.org/abs/2211.01095 for detailed derivations
        carry = sigma_t / sigma_s0
        decay = jnp.exp(-h) - 1.0
        if solver == "midpoint":
            x_t = carry * sample - (alpha_t * decay) * D0 - 0.5 * (alpha_t * decay) * D1
        elif solver == "heun":
            x_t = carry * sample - (alpha_t * decay) * D0 + (alpha_t * (decay / h + 1.0)) * D1
    elif algorithm == "dpmsolver":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        carry = alpha_t / alpha_s0
        growth = jnp.exp(h) - 1.0
        if solver == "midpoint":
            x_t = carry * sample - (sigma_t * growth) * D0 - 0.5 * (sigma_t * growth) * D1
        elif solver == "heun":
            x_t = carry * sample - (sigma_t * growth) * D0 - (sigma_t * (growth / h - 1.0)) * D1
    return x_t
def multistep_dpm_solver_third_order_update(
    self,
    state: DPMSolverMultistepSchedulerState,
    model_output_list: jnp.ndarray,
    timestep_list: List[int],
    prev_timestep: int,
    sample: jnp.ndarray,
) -> jnp.ndarray:
    """
    Advance the sample by one third-order multistep DPM-Solver step.

    Args:
        state (`DPMSolverMultistepSchedulerState`):
            scheduler state providing the `lambda_t`, `alpha_t` and `sigma_t` lookup tables.
        model_output_list (`List[jnp.ndarray]`):
            converted model outputs; the last three entries are used (`[-1]` is the current one).
        timestep_list (`List[int]`):
            discrete timesteps matching the last three entries of `model_output_list`.
        prev_timestep (`int`): previous discrete timestep in the diffusion chain (the step being moved to).
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.

    Returns:
        `jnp.ndarray`: the sample tensor at the previous timestep.
    """
    target = prev_timestep
    cur, older, oldest = timestep_list[-1], timestep_list[-2], timestep_list[-3]
    out_cur, out_older, out_oldest = model_output_list[-1], model_output_list[-2], model_output_list[-3]

    lam_target = state.lambda_t[target]
    lam_cur = state.lambda_t[cur]
    lam_older = state.lambda_t[older]
    lam_oldest = state.lambda_t[oldest]
    alpha_t, alpha_s0 = state.alpha_t[target], state.alpha_t[cur]
    sigma_t, sigma_s0 = state.sigma_t[target], state.sigma_t[cur]

    # h is the step being taken in log-SNR space; h_0 / h_1 are the two preceding steps
    h = lam_target - lam_cur
    h_0 = lam_cur - lam_older
    h_1 = lam_older - lam_oldest
    r0, r1 = h_0 / h, h_1 / h

    # difference terms of the model output up to second order
    D0 = out_cur
    D1_0 = (1.0 / r0) * (out_cur - out_older)
    D1_1 = (1.0 / r1) * (out_older - out_oldest)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)

    algorithm = self.config.algorithm_type
    if algorithm == "dpmsolver++":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        carry = sigma_t / sigma_s0
        decay = jnp.exp(-h) - 1.0
        x_t = (
            carry * sample
            - (alpha_t * decay) * D0
            + (alpha_t * (decay / h + 1.0)) * D1
            - (alpha_t * ((decay + h) / h**2 - 0.5)) * D2
        )
    elif algorithm == "dpmsolver":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        carry = alpha_t / alpha_s0
        growth = jnp.exp(h) - 1.0
        x_t = (
            carry * sample
            - (sigma_t * growth) * D0
            - (sigma_t * (growth / h - 1.0)) * D1
            - (sigma_t * ((growth - h) / h**2 - 0.5)) * D2
        )
    return x_t
def step(
    self,
    state: DPMSolverMultistepSchedulerState,
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
    return_dict: bool = True,
) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process
    from the learned model outputs (most often the predicted noise).

    The first-, second- and third-order candidate updates are all evaluated unconditionally; the one to return is
    then picked with `jax.lax.select` (array-valued conditions) and plain Python branches (config values).

    Args:
        state (`DPMSolverMultistepSchedulerState`):
            the `FlaxDPMSolverMultistepScheduler` state data class instance.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.
        return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class

    Returns:
        [`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if
        `return_dict` is True, otherwise a `tuple` of `(prev_sample, state)`.

    Raises:
        ValueError: if `state.num_inference_steps` is `None` (i.e. `set_timesteps` has not been run).
    """
    if state.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # locate `timestep` inside the schedule; `size=1` gives the result a fixed shape
    (step_index,) = jnp.where(state.timesteps == timestep, size=1)
    step_index = step_index[0]

    # on the last schedule entry the target timestep is 0, otherwise it is the next entry
    prev_timestep = jax.lax.select(step_index == len(state.timesteps) - 1, 0, state.timesteps[step_index + 1])

    model_output = self.convert_model_output(state, model_output, timestep, sample)

    # shift the history buffer by one slot and store the newest converted output in the last slot
    model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0)
    model_outputs_new = model_outputs_new.at[-1].set(model_output)
    state = state.replace(
        model_outputs=model_outputs_new,
        prev_timestep=prev_timestep,
        cur_sample=sample,
    )

    def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
        # first-order candidate, using only the newest model output
        return self.dpm_solver_first_order_update(
            state,
            state.model_outputs[-1],
            state.timesteps[step_index],
            state.prev_timestep,
            state.cur_sample,
        )

    def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
        def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
            # second-order candidate, using the current and the preceding schedule entries
            timestep_list = jnp.array([state.timesteps[step_index - 1], state.timesteps[step_index]])
            return self.multistep_dpm_solver_second_order_update(
                state,
                state.model_outputs,
                timestep_list,
                state.prev_timestep,
                state.cur_sample,
            )

        def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
            # third-order candidate, using the current and the two preceding schedule entries
            timestep_list = jnp.array(
                [
                    state.timesteps[step_index - 2],
                    state.timesteps[step_index - 1],
                    state.timesteps[step_index],
                ]
            )
            return self.multistep_dpm_solver_third_order_update(
                state,
                state.model_outputs,
                timestep_list,
                state.prev_timestep,
                state.cur_sample,
            )

        # both candidates are computed before one of them is selected
        step_2_output = step_2(state)
        step_3_output = step_3(state)

        if self.config.solver_order == 2:
            return step_2_output
        elif self.config.lower_order_final and len(state.timesteps) < 15:
            # short schedules: stay at order 2 during warm-up and on the second-to-last step
            return jax.lax.select(
                state.lower_order_nums < 2,
                step_2_output,
                jax.lax.select(
                    step_index == len(state.timesteps) - 2,
                    step_2_output,
                    step_3_output,
                ),
            )
        else:
            # order 2 until two lower-order steps have been counted, then order 3
            return jax.lax.select(
                state.lower_order_nums < 2,
                step_2_output,
                step_3_output,
            )

    step_1_output = step_1(state)
    step_23_output = step_23(state)

    if self.config.solver_order == 1:
        prev_sample = step_1_output

    elif self.config.lower_order_final and len(state.timesteps) < 15:
        # short schedules: first-order on the very first call and on the last schedule entry
        prev_sample = jax.lax.select(
            state.lower_order_nums < 1,
            step_1_output,
            jax.lax.select(
                step_index == len(state.timesteps) - 1,
                step_1_output,
                step_23_output,
            ),
        )

    else:
        prev_sample = jax.lax.select(
            state.lower_order_nums < 1,
            step_1_output,
            step_23_output,
        )

    # count the step just taken, saturating at the configured solver order
    state = state.replace(
        lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order),
    )

    if not return_dict:
        return (prev_sample, state)

    return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state)
def scale_model_input(
    self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep. This scheduler applies no scaling: `sample` is returned as is.

    Args:
        state (`DPMSolverMultistepSchedulerState`):
            the `FlaxDPMSolverMultistepScheduler` state data class instance (unused here).
        sample (`jnp.ndarray`): input sample
        timestep (`int`, optional): current timestep (unused here)

    Returns:
        `jnp.ndarray`: the input sample, unchanged
    """
    return sample

def add_noise(
    self,
    state: DPMSolverMultistepSchedulerState,
    original_samples: jnp.ndarray,
    noise: jnp.ndarray,
    timesteps: jnp.ndarray,
) -> jnp.ndarray:
    """
    Forward the arguments to `add_noise_common`, passing `state.common` as its first argument.

    Args:
        state (`DPMSolverMultistepSchedulerState`): scheduler state; only `state.common` is read.
        original_samples (`jnp.ndarray`): forwarded unchanged to `add_noise_common`.
        noise (`jnp.ndarray`): forwarded unchanged to `add_noise_common`.
        timesteps (`jnp.ndarray`): forwarded unchanged to `add_noise_common`.

    Returns:
        `jnp.ndarray`: whatever `add_noise_common` returns.
    """
    return add_noise_common(state.common, original_samples, noise, timesteps)

def __len__(self) -> int:
    """Return the configured number of training timesteps (`config.num_train_timesteps`)."""
    return self.config.num_train_timesteps
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    For every step `i`, beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor holding the betas used by the scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type not in ("cosine", "exp"):
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    if alpha_transform_type == "cosine":
        # squared-cosine cumulative-alpha curve
        alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2  # noqa: E731
    else:
        # exponentially decaying cumulative-alpha curve
        alpha_bar = lambda t: math.exp(t * -12.0)  # noqa: E731

    total = num_diffusion_timesteps
    schedule = [
        min(1 - alpha_bar((step + 1) / total) / alpha_bar(step / total), max_beta) for step in range(total)
    ]
    return torch.tensor(schedule, dtype=torch.float32)
trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. solver_order (`int`, defaults to 2): The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. dynamic_thresholding_ratio (`float`, defaults to 0.995): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. solver_type (`str`, defaults to `midpoint`): Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
lower_order_final (`bool`, defaults to `True`): Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. euler_at_final (`bool`, defaults to `False`): Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference steps, but sometimes may result in blurring. use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. variance_type (`str`, *optional*): Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output contains the predicted Gaussian variance. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. 
# names of the schedulers listed in `KarrasDiffusionSchedulers`
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1

@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    solver_order: int = 2,
    prediction_type: str = "epsilon",
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    algorithm_type: str = "dpmsolver++",
    solver_type: str = "midpoint",
    lower_order_final: bool = True,
    euler_at_final: bool = False,
    use_karras_sigmas: Optional[bool] = False,
    lambda_min_clipped: float = -float("inf"),
    variance_type: Optional[str] = None,
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
):
    """
    Build the beta schedule and the derived alpha / sigma / lambda tables, validate the solver settings and
    initialise the mutable sampling state. The parameters are described in the class docstring.

    Raises:
        NotImplementedError: if `beta_schedule`, `algorithm_type` or `solver_type` is not one of the values
            handled below.
    """
    # the non-"++" algorithm variants are still accepted but emit a deprecation notice
    if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
        deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
        deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)

    # explicitly supplied betas take precedence over any named schedule
    if trained_betas is not None:
        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        # (linear in sqrt(beta), then squared)
        self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        self.betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
    # Currently we only support VP-type noise schedule
    self.alpha_t = torch.sqrt(self.alphas_cumprod)
    self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
    # lambda_t = log(alpha_t) - log(sigma_t)
    self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
    self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # settings for DPM-Solver
    # "deis" is remapped to "dpmsolver++" in the registered config; any other unknown value is rejected
    if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
        if algorithm_type == "deis":
            self.register_to_config(algorithm_type="dpmsolver++")
        else:
            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

    # "logrho" / "bh1" / "bh2" are remapped to "midpoint"; any other unknown value is rejected
    if solver_type not in ["midpoint", "heun"]:
        if solver_type in ["logrho", "bh1", "bh2"]:
            self.register_to_config(solver_type="midpoint")
        else:
            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

    # settable values (overwritten by `set_timesteps`)
    self.num_inference_steps = None
    timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy()
    self.timesteps = torch.from_numpy(timesteps)
    # one history slot per solver order
    self.model_outputs = [None] * solver_order
    self.lower_order_nums = 0
    self._step_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
    self.use_karras_sigmas = use_karras_sigmas

@property
def step_index(self):
    """
    The index counter for current timestep. It will increase 1 after each scheduler step.

    Read-only view of `self._step_index`, which `__init__` and `set_timesteps` reset to `None`.
    """
    return self._step_index
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Besides `self.timesteps`, this also recomputes `self.sigmas` and `self.noisiest_timestep` and resets the
    multistep history (`model_outputs`, `lower_order_nums`, `_step_index`).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

    Raises:
        ValueError: if `config.timestep_spacing` is not `linspace`, `leading` or `trailing`.
    """
    # Clipping the minimum of all lambda(t) for numerical stability.
    # This is critical for cosine (squaredcos_cap_v2) noise schedule.
    clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped).item()
    self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    # All three branches produce timesteps in ascending order.
    if self.config.timestep_spacing == "linspace":
        # evenly spaced from 0 to the noisiest timestep; the final point itself is dropped
        timesteps = (
            np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64)
        )
    elif self.config.timestep_spacing == "leading":
        step_ratio = (self.noisiest_timestep + 1) // (num_inference_steps + 1)
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[:-1].copy().astype(np.int64)
        timesteps += self.config.steps_offset
    elif self.config.timestep_spacing == "trailing":
        step_ratio = self.config.num_train_timesteps / num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        # (generated in descending order, then reversed with [::-1])
        timesteps = np.arange(self.noisiest_timestep + 1, 0, -step_ratio).round()[::-1].copy().astype(np.int64)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', "
            "'leading' or 'trailing'."
        )

    # sigma table over all training timesteps
    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    log_sigmas = np.log(sigmas)
    if self.config.use_karras_sigmas:
        # replace the spacing chosen above by the Karras schedule and map each sigma back to a timestep
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        timesteps = timesteps.copy().astype(np.int64)
        # the last sigma is repeated so `sigmas` has one more entry than the Karras schedule
        sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
    else:
        # interpolate the sigma table at the selected timesteps
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        # sigma at the noisiest timestep is appended as the final entry
        sigma_max = (
            (1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep]
        ) ** 0.5
        sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32)

    self.sigmas = torch.from_numpy(sigmas)

    # when num_inference_steps == num_train_timesteps, we can end up with
    # duplicates in timesteps.
    # (keep the first occurrence of each value, preserving the original order)
    _, unique_indices = np.unique(timesteps, return_index=True)
    timesteps = timesteps[np.sort(unique_indices)]

    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

    # number of timesteps left after de-duplication
    self.num_inference_steps = len(timesteps)

    self.model_outputs = [
        None,
    ] * self.config.solver_order
    self.lower_order_nums = 0

    # add an index counter for schedulers that allow duplicated timesteps
    self._step_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    """
    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
    prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
    s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
    pixels from saturation at each step. We find that dynamic thresholding results in significantly better
    photorealism as well as better image-text alignment, especially when using very large guidance weights."

    https://arxiv.org/abs/2205.11487

    Args:
        sample (`torch.Tensor`):
            predicted `x_0` with at least two dimensions; the first is treated as the batch axis and every
            remaining axis is flattened for the per-item quantile.

    Returns:
        `torch.Tensor`: thresholded sample with the shape and dtype of the input.
    """
    dtype = sample.dtype
    batch_size, channels, *remaining_dims = sample.shape

    if dtype not in (torch.float32, torch.float64):
        sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

    # Flatten sample for doing quantile calculation along each image.
    # `np.prod` of an empty list is the float 1.0, so cast to int to keep 2-D inputs reshapeable.
    sample = sample.reshape(batch_size, channels * int(np.prod(remaining_dims)))

    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

    s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
    s = torch.clamp(
        s, min=1, max=self.config.sample_max_value
    )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
    s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
    sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

    sample = sample.reshape(batch_size, channels, *remaining_dims)
    sample = sample.to(dtype)

    return sample

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
    """
    Map `sigma` to a fractional index into `log_sigmas` by linear interpolation in log-sigma space.

    Args:
        sigma: value(s) to locate; clamped from below at 1e-10 before taking the log. Must expose `.shape`.
        log_sigmas: 1-D array of log-sigma values to interpolate in.

    Returns:
        The interpolated index, reshaped to `sigma.shape`.
    """
    # get log sigma
    log_sigma = np.log(np.maximum(sigma, 1e-10))

    # get distribution
    dists = log_sigma - log_sigmas[:, np.newaxis]

    # get sigmas range: index of the last entry that is <= log_sigma, capped so that `high_idx` stays in range
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1

    low = log_sigmas[low_idx]
    high = log_sigmas[high_idx]

    # interpolate sigmas
    w = (low - log_sigma) / (low - high)
    w = np.clip(w, 0, 1)

    # transform interpolation to time range
    t = (1 - w) * low_idx + w * high_idx
    t = t.reshape(sigma.shape)
    return t

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
    """Return `(alpha_t, sigma_t)` with `alpha_t = 1 / sqrt(sigma**2 + 1)` and `sigma_t = sigma * alpha_t`."""
    alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
    sigma_t = sigma * alpha_t

    return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
    """
    Constructs the noise schedule of Karras et al. (2022).

    The end points come from `config.sigma_min` / `config.sigma_max` when the config defines them with a non-`None`
    value; otherwise `in_sigmas[-1]` is used as the minimum and `in_sigmas[0]` as the maximum.

    Args:
        in_sigmas: sequence of sigmas whose first and last entries supply the fallback end points.
        num_inference_steps: number of sigmas to generate.

    Returns:
        The `num_inference_steps` sigmas, running from the maximum end point to the minimum one.
    """
    # Hack to make sure that other schedulers which copy this function don't break
    # TODO: Add this logic to the other schedulers
    configured_min = getattr(self.config, "sigma_min", None)
    configured_max = getattr(self.config, "sigma_max", None)

    lowest = in_sigmas[-1].item() if configured_min is None else configured_min
    highest = in_sigmas[0].item() if configured_max is None else configured_max

    rho = 7.0  # 7.0 is the value used in the paper
    inv_rho = 1 / rho
    lowest_root = lowest**inv_rho
    highest_root = highest**inv_rho

    # interpolate linearly between the rho-th roots, then raise back to the rho-th power
    ramp = np.linspace(0, 1, num_inference_steps)
    return (highest_root + ramp * (lowest_root - highest_root)) ** rho
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
def convert_model_output(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """
    Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
    designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
    integral of the data prediction model.

    <Tip>

    The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
    prediction and data prediction models.

    </Tip>

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        *args:
            Deprecated positional form `(timestep, sample)`. A non-`None` timestep only triggers a deprecation
            notice; the noise level is read from `self.sigmas[self.step_index]`.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        **kwargs:
            May carry the deprecated `timestep` keyword.

    Returns:
        `torch.Tensor`:
            The converted model output (an `x0` prediction for the `++` algorithms, an epsilon prediction
            otherwise).

    Raises:
        ValueError: if `sample` is missing, or if `config.prediction_type` is not `epsilon`, `sample` or
            `v_prediction`.
    """
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    if sample is None:
        if len(args) > 1:
            sample = args[1]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # DPM-Solver++ needs to solve an integral of the data prediction model.
    if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output.
            if self.config.variance_type in ["learned", "learned_range"]:
                # keep only the first three channels (the rest is treated as predicted variance)
                model_output = model_output[:, :3]
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = alpha_t * sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)

        return x0_pred

    # DPM-Solver needs to solve an integral of the noise prediction model.
    elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output.
            if self.config.variance_type in ["learned", "learned_range"]:
                epsilon = model_output[:, :3]
            else:
                epsilon = model_output
        elif self.config.prediction_type == "sample":
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            epsilon = (sample - alpha_t * model_output) / sigma_t
        elif self.config.prediction_type == "v_prediction":
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            epsilon = alpha_t * model_output + sigma_t * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            # threshold in x0 space, then convert the thresholded prediction back to epsilon
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = (sample - sigma_t * epsilon) / alpha_t
            x0_pred = self._threshold_sample(x0_pred)
            epsilon = (sample - alpha_t * x0_pred) / sigma_t

        return epsilon
""" timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) if sample is None: if len(args) > 2: sample = args[2] else: raise ValueError(" missing `sample` as a required keyward argument") if timestep is not None: deprecate( "timesteps", "1.0.0", "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) if prev_timestep is not None: deprecate( "prev_timestep", "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s = torch.log(alpha_s) - torch.log(sigma_s) h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output elif self.config.algorithm_type == "sde-dpmsolver++": assert noise is not None x_t = ( (sigma_t / sigma_s * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.algorithm_type == "sde-dpmsolver": assert noise is not None x_t = ( (alpha_t / alpha_s) * sample - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise ) return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update def multistep_dpm_solver_second_order_update( self, model_output_list: List[torch.Tensor], *args, 
def multistep_dpm_solver_second_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the second-order multistep DPMSolver.

    Args:
        model_output_list (`List[torch.Tensor]`):
            The direct outputs from learned diffusion model at current and latter timesteps. The last entry is
            the most recent output.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. For backward compatibility it may
            also be given as the third positional argument.
        noise (`torch.Tensor`, *optional*):
            Noise term used by the `sde-dpmsolver` and `sde-dpmsolver++` variants. Required for those variants.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.

    Raises:
        `ValueError`: If `sample` is missing, if an SDE variant is selected without `noise`, or if
            `algorithm_type` / `solver_type` is not a supported value.
    """
    # Legacy positional/keyword arguments are still accepted but only trigger a deprecation notice.
    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    algorithm_type = self.config.algorithm_type
    solver_type = self.config.solver_type
    # Validate up front so every branch below is guaranteed to bind `x_t`.
    if algorithm_type not in ("dpmsolver++", "dpmsolver", "sde-dpmsolver++", "sde-dpmsolver"):
        raise ValueError(f"Unsupported `algorithm_type` for the second-order update: {algorithm_type}")
    if solver_type not in ("midpoint", "heun"):
        raise ValueError(f"Unsupported `solver_type` for the second-order update: {solver_type}")
    # Explicit check instead of `assert`: asserts are removed when Python runs with -O.
    if algorithm_type in ("sde-dpmsolver", "sde-dpmsolver++") and noise is None:
        raise ValueError(f"`noise` must be provided when `algorithm_type` is '{algorithm_type}'.")

    sigma_t, sigma_s0, sigma_s1 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
    )

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

    m0, m1 = model_output_list[-1], model_output_list[-2]

    h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
    r0 = h_0 / h
    # D0 is the latest output; D1 is the scaled finite difference of the last two outputs.
    D0, D1 = m0, (1.0 / r0) * (m0 - m1)
    if algorithm_type == "dpmsolver++":
        # See https://arxiv.org/abs/2211.01095 for detailed derivations
        if solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
            )
        elif solver_type == "heun":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            )
    elif algorithm_type == "dpmsolver":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        if solver_type == "midpoint":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (torch.exp(h) - 1.0)) * D0
                - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
            )
        elif solver_type == "heun":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (torch.exp(h) - 1.0)) * D0
                - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
            )
    elif algorithm_type == "sde-dpmsolver++":
        if solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        elif solver_type == "heun":
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
    elif algorithm_type == "sde-dpmsolver":
        if solver_type == "midpoint":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                - (sigma_t * (torch.exp(h) - 1.0)) * D1
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
        elif solver_type == "heun":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
    return x_t
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the third-order multistep DPMSolver.

    Args:
        model_output_list (`List[torch.Tensor]`):
            The direct outputs from learned diffusion model at current and latter timesteps. The last entry is
            the most recent output.
        sample (`torch.Tensor`):
            A current instance of a sample created by diffusion process. For backward compatibility it may also
            be given as the third positional argument.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.

    Raises:
        `ValueError`: If `sample` is missing.
        `NotImplementedError`: If `self.config.algorithm_type` is anything other than `dpmsolver++` or
            `dpmsolver`; this update has no branch for the SDE variants.
    """
    # Legacy positional/keyword arguments are still accepted but only trigger a deprecation notice.
    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    algorithm_type = self.config.algorithm_type
    if algorithm_type not in ("dpmsolver++", "dpmsolver"):
        # Previously this fell through to `return x_t` with `x_t` unbound (UnboundLocalError).
        raise NotImplementedError(
            f"The third-order update is only implemented for `dpmsolver++` and `dpmsolver`, got {algorithm_type}"
        )

    sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
        self.sigmas[self.step_index - 2],
    )

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
    alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)

    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
    lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

    m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

    h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
    r0, r1 = h_0 / h, h_1 / h
    D0 = m0
    # Finite differences built from the last three model outputs.
    D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
    if algorithm_type == "dpmsolver++":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        x_t = (
            (sigma_t / sigma_s0) * sample
            - (alpha_t * (torch.exp(-h) - 1.0)) * D0
            + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
        )
    else:
        # algorithm_type == "dpmsolver"
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        x_t = (
            (alpha_t / alpha_s0) * sample
            - (sigma_t * (torch.exp(h) - 1.0)) * D0
            - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
            - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
        )
    return x_t
def _init_step_index(self, timestep):
    """
    Initialise `self._step_index` from a timestep value.

    The timestep is looked up in `self.timesteps`. With no match, the last index is used. With several
    matches, the second one is used so that a sigma is not skipped when denoising starts in the middle of
    the schedule (e.g. for image-to-image). With exactly one match, that index is used.
    """
    schedule = self.timesteps
    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(schedule.device)

    matches = (schedule == timestep).nonzero()
    match_count = len(matches)

    if match_count == 0:
        self._step_index = len(schedule) - 1
        return

    which = 1 if match_count > 1 else 0
    self._step_index = matches[which].item()
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance the sample by one scheduler step using the multistep DPMSolver.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain. Only used to initialise the step index on
            the first call.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator, used when noise has to be drawn for the SDE variants.
        variance_noise (`torch.Tensor`):
            Noise to use for the SDE variants instead of drawing it with `generator`.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    # Improve numerical stability for small number of steps
    schedule_length = len(self.timesteps)
    at_last_step = self.step_index == schedule_length - 1
    at_second_to_last_step = self.step_index == schedule_length - 2
    lower_order_final = at_last_step and (
        self.config.euler_at_final or (self.config.lower_order_final and schedule_length < 15)
    )
    lower_order_second = at_second_to_last_step and self.config.lower_order_final and schedule_length < 15

    model_output = self.convert_model_output(model_output, sample=sample)

    # Shift the history one slot to the left and store the newest output at the end.
    history = self.model_outputs
    for slot in range(self.config.solver_order - 1):
        history[slot] = history[slot + 1]
    history[-1] = model_output

    # Only the SDE variants consume noise.
    is_stochastic = self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]
    if not is_stochastic:
        noise = None
    elif variance_noise is None:
        noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
        )
    else:
        noise = variance_noise

    # Choose the update order: configured order, limited by warm-up history and the end-of-schedule rules.
    use_first_order = self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final
    if use_first_order:
        prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
    elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
        prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
    else:
        prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    # upon completion increase step index by one
    self._step_index += 1

    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Return `sample` unchanged.

    Exists so this scheduler can be swapped with schedulers that scale the denoising model input depending
    on the current timestep.

    Args:
        sample (`torch.Tensor`):
            The input sample.

    Returns:
        `torch.Tensor`:
            The same sample that was passed in.
    """
    return sample


def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Mix `noise` into `original_samples` at the noise levels selected by `timesteps`.

    Each timestep is looked up in `self.timesteps` (no match: last index; several matches: the second one)
    and the matching sigma is converted with `self._sigma_to_alpha_sigma_t` into the two mixing weights.
    """
    target_device = original_samples.device
    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    sigmas = self.sigmas.to(device=target_device, dtype=original_samples.dtype)

    downcast = target_device.type == "mps" and torch.is_floating_point(timesteps)
    if downcast:
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(target_device, dtype=torch.float32)
        timesteps = timesteps.to(target_device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(target_device)
        timesteps = timesteps.to(target_device)

    def locate(timestep):
        hits = (schedule_timesteps == timestep).nonzero()
        if len(hits) == 0:
            return len(schedule_timesteps) - 1
        if len(hits) > 1:
            return hits[1].item()
        return hits[0].item()

    step_indices = [locate(timestep) for timestep in timesteps]

    sigma = sigmas[step_indices].flatten()
    # Append trailing singleton axes until sigma broadcasts against the samples.
    while len(sigma.shape) < len(original_samples.shape):
        sigma = sigma.unsqueeze(-1)

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
    return alpha_t * original_samples + sigma_t * noise
def __len__(self):
    # The scheduler's length is the configured number of training timesteps.
    return self.config.num_train_timesteps
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy.

    Args:
        x (Tensor): Tensor whose shape, device and dtype define the default starting value `w0`.
        t0, t1: Interval end points; they are sorted internally and the orientation is remembered.
        seed (int or sequence of int, optional): One seed builds a single tree. A sequence builds one tree
            per batch item and must have length `x.shape[0]`. If `None`, a random seed is drawn.
        **kwargs: Forwarded to `torchsde.BrownianTree`. A `w0` entry overrides the default starting value.

    Raises:
        ValueError: If a sequence of seeds does not have one entry per batch item.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        # Pop (not get) `w0`: it is passed positionally below, so leaving it in `kwargs` would hand
        # `torchsde.BrownianTree` the same argument twice.
        w0 = kwargs.pop("w0", None)
        if w0 is None:
            w0 = torch.zeros_like(x)
        if seed is None:
            seed = torch.randint(0, 2**63 - 1, []).item()
        try:
            seed_count = len(seed)
        except TypeError:
            # A scalar seed: a single, unbatched tree.
            seed = [seed]
            self.batched = False
        else:
            # Explicit check instead of `assert`, which is removed when Python runs with -O.
            if seed_count != x.shape[0]:
                raise ValueError(
                    f"Expected one seed per batch item ({x.shape[0]}), but got {seed_count} seeds."
                )
            # Each per-item tree works on a single batch element.
            w0 = w0[0]
            self.batched = True
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        """Return `(low, high, sign)` where `sign` is 1 if `a < b` and -1 otherwise."""
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        """Return the Brownian increment between `t0` and `t1`, oriented by the argument order."""
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]


class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each
            with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        """Return the tree increment between the two sigmas, divided by the square root of their distance."""
        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Build a beta schedule by discretizing an alpha_bar function, i.e. the cumulative product of (1-beta)
    over time on t = [0,1].

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): upper bound applied to every beta; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor with `num_diffusion_timesteps` betas
    """

    def cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar = cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar = exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.
    betas = [min(1 - alpha_bar((i + 1) / total) / alpha_bar(i / total), max_beta) for i in range(total)]
    return torch.tensor(betas, dtype=torch.float32)
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,  # sensible defaults
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    prediction_type: str = "epsilon",
    use_karras_sigmas: Optional[bool] = False,
    noise_sampler_seed: Optional[int] = None,
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
):
    """
    Build the beta/alpha tables and initialise the schedule over all training timesteps.

    `trained_betas`, when given, takes precedence over `beta_schedule`. An unknown `beta_schedule` raises
    `NotImplementedError`.
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        betas = sqrt_betas**2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.betas = betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # set all values
    self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

    self.use_karras_sigmas = use_karras_sigmas
    self.noise_sampler_seed = noise_sampler_seed
    self.noise_sampler = None
    self._begin_index = None
    self._step_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Return the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

    When the timestep occurs more than once, the second occurrence is returned. This way the very first
    `step` does not accidentally skip a sigma when denoising starts in the middle of the schedule
    (e.g. for image-to-image).
    """
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (schedule == timestep).nonzero()

    if len(matches) > 1:
        return matches[1].item()
    return matches[0].item()
def _init_step_index(self, timestep):
    """Set `_step_index` from `_begin_index` when one was set, otherwise by looking up `timestep`."""
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return

    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)


@property
def init_noise_sigma(self):
    """Standard deviation of the initial noise distribution."""
    uses_plain_max = self.config.timestep_spacing in ["linspace", "trailing"]
    if not uses_plain_max:
        return (self.sigmas.max() ** 2 + 1) ** 0.5
    return self.sigmas.max()


@property
def step_index(self):
    """
    Current position in the schedule; `step` increases it by one on every call.
    """
    return self._step_index


@property
def begin_index(self):
    """
    Index of the first timestep, as set through `set_begin_index` (or `None` when it was never set).
    """
    return self._begin_index


def set_begin_index(self, begin_index: int = 0):
    """
    Store the begin index for the scheduler. Intended to be called by the pipeline before inference.

    Args:
        begin_index (`int`):
            The begin index for the scheduler.
    """
    self._begin_index = begin_index


def scale_model_input(
    self,
    sample: torch.Tensor,
    timestep: Union[float, torch.Tensor],
) -> torch.Tensor:
    """
    Divide `sample` by `sqrt(sigma**2 + 1)` for the sigma of the current step.

    In the first-order state the sigma at `self.step_index` is used; otherwise the stored midpoint sigma.

    Args:
        sample (`torch.Tensor`):
            The input sample.
        timestep (`float` or `torch.Tensor`):
            The current timestep in the diffusion chain; only used to initialise the step index when it is
            still unset.

    Returns:
        `torch.Tensor`:
            A scaled input sample.
    """
    if self.step_index is None:
        self._init_step_index(timestep)

    current_sigma = self.sigmas[self.step_index]
    if self.state_in_first_order:
        effective_sigma = current_sigma
    else:
        effective_sigma = self.mid_point_sigma
    return sample / ((effective_sigma**2 + 1) ** 0.5)
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Union[str, torch.device] = None,
    num_train_timesteps: Optional[int] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        num_train_timesteps (`int`, *optional*):
            Number of training timesteps to space over; falls back to `self.config.num_train_timesteps`.
    """
    self.num_inference_steps = num_inference_steps

    num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    spacing = self.config.timestep_spacing
    if spacing == "linspace":
        timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
    elif spacing == "leading":
        step_ratio = num_train_timesteps // self.num_inference_steps
        # multiples of the integer ratio, in descending order
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
        timesteps += self.config.steps_offset
    elif spacing == "trailing":
        step_ratio = num_train_timesteps / self.num_inference_steps
        # rounded multiples of the ratio counted down from the end of the training range
        timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    train_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    log_sigmas = np.log(train_sigmas)
    sigmas = np.interp(timesteps, np.arange(0, len(train_sigmas)), train_sigmas)

    if self.config.use_karras_sigmas:
        sigmas = self._convert_to_karras(in_sigmas=sigmas)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

    second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)

    # Append the terminal sigma of 0, then duplicate every interior sigma.
    sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
    sigmas = torch.from_numpy(sigmas).to(device=device)
    self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

    # Duplicate every timestep after the first and put the midpoint timesteps on the odd slots.
    timesteps = torch.from_numpy(timesteps)
    midpoint_timesteps = torch.from_numpy(second_order_timesteps)
    timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
    timesteps[1::2] = midpoint_timesteps

    if str(device).startswith("mps"):
        # mps does not support float64
        self.timesteps = timesteps.to(device, dtype=torch.float32)
    else:
        self.timesteps = timesteps.to(device=device)

    # empty first order variables
    self.sample = None
    self.mid_point_sigma = None

    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
    self.noise_sampler = None


def _second_order_timesteps(self, sigmas, log_sigmas):
    """Return the timesteps of the midpoints (in `-log(sigma)` space) between consecutive sigmas."""
    neg_log_sigmas = -np.log(sigmas)
    midpoints = neg_log_sigmas[:-1] + np.diff(neg_log_sigmas) * 0.5
    midpoint_sigmas = np.exp(-midpoints)
    return np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in midpoint_sigmas])


def _sigma_to_t(self, sigma, log_sigmas):
    """Map `sigma` to a fractional index into `log_sigmas` by linear interpolation in log space."""
    log_sigma = np.log(np.maximum(sigma, 1e-10))

    # Rows: entries of log_sigmas; True where the table entry does not exceed log_sigma.
    not_above = (log_sigma - log_sigmas[:, np.newaxis]) >= 0

    # Bracketing indices, clipped so that the upper one stays inside the table.
    lower_idx = np.cumsum(not_above, axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    upper_idx = lower_idx + 1

    lower = log_sigmas[lower_idx]
    upper = log_sigmas[upper_idx]

    # Interpolation weight between the two bracketing entries.
    weight = np.clip((lower - log_sigma) / (lower - upper), 0, 1)

    interpolated = (1 - weight) * lower_idx + weight * upper_idx
    return interpolated.reshape(sigma.shape)
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
    """
    Constructs the noise schedule of Karras et al. (2022).

    The first and last entries of `in_sigmas` are taken as the maximum and minimum sigma, and
    `self.num_inference_steps` values are interpolated between them in `sigma ** (1 / rho)` space.
    """
    rho = 7.0  # 7.0 is the value used in the paper
    largest: float = in_sigmas[0].item()
    smallest: float = in_sigmas[-1].item()

    start = largest ** (1 / rho)
    stop = smallest ** (1 / rho)
    ramp = np.linspace(0, 1, self.num_inference_steps)
    return (start + ramp * (stop - start)) ** rho


@property
def state_in_first_order(self):
    """`True` while no sample is stored, i.e. the next `step` call is the first half of a two-stage step."""
    return self.sample is None
def step(
    self,
    model_output: Union[torch.Tensor, np.ndarray],
    timestep: Union[float, torch.Tensor],
    sample: Union[torch.Tensor, np.ndarray],
    return_dict: bool = True,
    s_noise: float = 1.0,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    The scheduler alternates between two states. In the first-order state the call advances to the midpoint of
    the interval and stores the incoming sample. In the following call that stored sample is reused to take
    the full step, after which the stored state is cleared again.

    Args:
        model_output (`torch.Tensor` or `np.ndarray`):
            The direct output from learned diffusion model.
        timestep (`float` or `torch.Tensor`):
            The current discrete timestep in the diffusion chain. Only used to initialise the step index when
            it is still unset.
        sample (`torch.Tensor` or `np.ndarray`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
        s_noise (`float`, *optional*, defaults to 1.0):
            Scaling factor for noise added to the sample.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        `NotImplementedError`: If `prediction_type` is `sample`.
        `ValueError`: If `prediction_type` is not a recognised value.
    """
    if self.step_index is None:
        self._init_step_index(timestep)

    # Create a noise sampler if it hasn't been created yet
    if self.noise_sampler is None:
        # The sampler's interval runs from the smallest strictly positive sigma to the largest sigma.
        min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
        self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed)

    # Define functions to compute sigma and t from each other
    def sigma_fn(_t: torch.Tensor) -> torch.Tensor:
        return _t.neg().exp()

    def t_fn(_sigma: torch.Tensor) -> torch.Tensor:
        return _sigma.log().neg()

    if self.state_in_first_order:
        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]
    else:
        # 2nd order
        # The step index already advanced once, so the interval start is one entry back.
        sigma = self.sigmas[self.step_index - 1]
        sigma_next = self.sigmas[self.step_index]

    # Set the midpoint and step size for the current step
    midpoint_ratio = 0.5
    t, t_next = t_fn(sigma), t_fn(sigma_next)
    delta_time = t_next - t
    t_proposed = t + delta_time * midpoint_ratio

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    if self.config.prediction_type == "epsilon":
        sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed)
        pred_original_sample = sample - sigma_input * model_output
    elif self.config.prediction_type == "v_prediction":
        sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed)
        pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
            sample / (sigma_input**2 + 1)
        )
    elif self.config.prediction_type == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    if sigma_next == 0:
        # Final interval: a plain derivative step down to sigma 0, with no noise added.
        derivative = (sample - pred_original_sample) / sigma
        dt = sigma_next - sigma
        prev_sample = sample + derivative * dt
    else:
        if self.state_in_first_order:
            # First stage only goes as far as the midpoint.
            t_next = t_proposed
        else:
            # Second stage restarts from the sample stored during the first stage.
            sample = self.sample

        # Split the target sigma into a deterministic part (sigma_down) and a noise part (sigma_up).
        sigma_from = sigma_fn(t)
        sigma_to = sigma_fn(t_next)
        sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
        ancestral_t = t_fn(sigma_down)
        prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - (
            t - ancestral_t
        ).expm1() * pred_original_sample
        prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up

        if self.state_in_first_order:
            # store for 2nd order step
            self.sample = sample
            self.mid_point_sigma = sigma_fn(t_next)
        else:
            # free for "first order mode"
            self.sample = None
            self.mid_point_sigma = None

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] elif self.step_index is not None: # add_noise is called after first denoising step (for inpainting) step_indices = [self.step_index] * timesteps.shape[0] else: # add noise is called before first denoising step to create initial latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py ================================================ # Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous `alpha_bar(t)` curve, defined on t in [0, 1], into a beta schedule.

    `alpha_bar(t)` is the cumulative product of `(1 - beta)` up to time `t`. Each beta is obtained from the ratio
    of `alpha_bar` at two consecutive grid points and is capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`):
            The number of betas to produce.
        max_beta (`float`):
            The maximum beta to use; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`):
            The type of noise schedule for alpha_bar. Choose from `cosine` or `exp`.

    Returns:
        `torch.Tensor`: a float32 tensor with `num_diffusion_timesteps` betas.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar_fn = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar_fn = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((idx + 1) / total) / alpha_bar_fn(idx / total), max_beta) for idx in range(total)
    ]
    return torch.tensor(betas, dtype=torch.float32)
prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. dynamic_thresholding_ratio (`float`, defaults to 0.995): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. solver_type (`str`, defaults to `midpoint`): Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. lower_order_final (`bool`, defaults to `False`): Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. 
If `True`, the sigmas are determined according to a sequence of noise levels {σi}. final_sigmas_type (`str`, *optional*, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. variance_type (`str`, *optional*): Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output contains the predicted Gaussian variance. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = False, use_karras_sigmas: Optional[bool] = False, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, ): if algorithm_type == "dpmsolver": deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. 
Choose from `dpmsolver++` or `sde-dpmsolver++` instead" deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # Currently we only support VP-type noise schedule self.alpha_t = torch.sqrt(self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # settings for DPM-Solver if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") if solver_type not in ["midpoint", "heun"]: if solver_type in ["logrho", "bh1", "bh2"]: self.register_to_config(solver_type="midpoint") else: raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": raise ValueError( f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead." 
def get_order_list(self, num_inference_steps: int) -> List[int]:
    """
    Computes the solver order at each time step.

    The orders cycle through `1..solver_order`. When `config.lower_order_final` is set, the tail of the list is
    shortened to lower orders so that the list has exactly `num_inference_steps` entries. When
    `config.final_sigmas_type` is `"zero"`, the last entry is always forced to `1`: a higher-order update towards
    sigma 0 has an infinite `lambda_t`, which turns the step ratio into 0 and the finite-difference term into
    `inf`/`nan`.

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.

    Returns:
        `List[int]`: the solver order to use at each step.

    Raises:
        ValueError: if `config.solver_order` is not 1, 2 or 3.
    """
    steps = num_inference_steps
    order = self.config.solver_order
    if order > 3:
        raise ValueError("Order > 3 is not supported by this scheduler")
    if order < 1:
        # previously this fell through every branch and failed with an UnboundLocalError
        raise ValueError(f"Order must be 1, 2, 3, got {order}")

    cycle = list(range(1, order + 1))
    full_cycles, remainder = divmod(steps, order)

    if self.config.lower_order_final:
        if order == 3:
            if remainder == 0:
                orders = cycle * (full_cycles - 1) + [1, 2] + [1]
            elif remainder == 1:
                orders = cycle * full_cycles + [1]
            else:
                orders = cycle * full_cycles + [1, 2]
        elif order == 2:
            if remainder == 0:
                orders = cycle * (full_cycles - 1) + [1, 1]
            else:
                orders = cycle * full_cycles + [1]
        else:
            orders = [1] * steps
    else:
        orders = cycle * full_cycles

    # The final step lands on sigma == 0 when `final_sigmas_type` is "zero"; only the first-order update is
    # well defined there.
    if orders and getattr(self.config, "final_sigmas_type", None) == "zero":
        orders[-1] = 1

    return orders
def set_timesteps(
    self,
    num_inference_steps: int = None,
    device: Union[str, torch.device] = None,
    timesteps: Optional[List[int]] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Besides `self.timesteps` and `self.sigmas`, this resets `self.model_outputs`, `self.sample`,
    `self._step_index` and `self._begin_index`, and recomputes `self.order_list`. It may also switch
    `config.lower_order_final` to `True` (with a warning) when the requested setup cannot be run otherwise.

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
            passed, `num_inference_steps` must be `None`.

    Raises:
        ValueError: if not exactly one of `num_inference_steps` / `timesteps` is given, if `timesteps` is
            combined with `config.use_karras_sigmas=True`, or if `config.final_sigmas_type` is unknown.
    """
    if num_inference_steps is None and timesteps is None:
        raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
    if num_inference_steps is not None and timesteps is not None:
        raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
    if timesteps is not None and self.config.use_karras_sigmas:
        raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.")

    num_inference_steps = num_inference_steps or len(timesteps)
    self.num_inference_steps = num_inference_steps

    if timesteps is not None:
        timesteps = np.array(timesteps).astype(np.int64)
    else:
        # Clipping the minimum of all lambda(t) for numerical stability.
        # This is critical for cosine (squaredcos_cap_v2) noise schedule.
        # `searchsorted` returns a 0-d tensor; turn it into a plain int before mixing it with NumPy.
        clipped_idx = int(torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped))
        timesteps = (
            np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
            .round()[::-1][:-1]
            .copy()
            .astype(np.int64)
        )

    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    if self.config.use_karras_sigmas:
        log_sigmas = np.log(sigmas)
        sigmas = np.flip(sigmas).copy()
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
    else:
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

    if self.config.final_sigmas_type == "sigma_min":
        # plain float so that the concatenation below only sees NumPy-compatible values
        sigma_last = float(((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5)
    elif self.config.final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(
            f" `final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}"
        )
    sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

    # `torch.from_numpy` yields a CPU tensor; sigmas are kept on CPU to avoid too much CPU/GPU communication
    self.sigmas = torch.from_numpy(sigmas)

    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
    self.model_outputs = [None] * self.config.solver_order
    self.sample = None

    if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
        logger.warning(
            f"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure `num_inference_steps` is a multiple of `solver_order` when using `lower_order_final=False`."
        )
        self.register_to_config(lower_order_final=True)

    if not self.config.lower_order_final and self.config.final_sigmas_type == "zero":
        logger.warning(
            f" `final_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True."
        )
        self.register_to_config(lower_order_final=True)

    self.order_list = self.get_order_list(num_inference_steps)

    # add an index counter for schedulers that allow duplicated timesteps
    self._step_index = None
    self._begin_index = None
def _sigma_to_t(self, sigma, log_sigmas):
    """
    Map `sigma` to a (fractional) position in the `log_sigmas` table by linear interpolation in log space.

    Args:
        sigma: the sigma value(s) to convert; the result is reshaped to `sigma.shape`.
        log_sigmas: 1-D array with the logarithm of the reference sigmas.

    Returns:
        The interpolated index position(s), clipped to the range covered by `log_sigmas`.
    """
    # clamp before taking the log so that sigma == 0 does not produce -inf
    log_sigma = np.log(np.maximum(sigma, 1e-10))

    # signed distance of the query to every table entry (one row per table entry)
    offsets = log_sigma - log_sigmas[:, np.newaxis]

    # index of the last table entry that is <= the query, kept low enough that `low_idx + 1` stays valid
    last_valid_low = log_sigmas.shape[0] - 2
    reached = np.cumsum((offsets >= 0), axis=0)
    low_idx = reached.argmax(axis=0).clip(max=last_valid_low)
    high_idx = low_idx + 1

    low, high = log_sigmas[low_idx], log_sigmas[high_idx]

    # interpolation weight between the two neighbours, clipped so out-of-range queries snap to the ends
    w = np.clip((low - log_sigma) / (low - high), 0, 1)

    t = (1 - w) * low_idx + w * high_idx
    return t.reshape(sigma.shape)
def convert_model_output(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """
    Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
    designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
    integral of the data prediction model.

    <Tip>

    The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
    prediction and data prediction models.

    </Tip>

    The noise level is always read from `self.sigmas[self.step_index]`; the legacy `timestep` argument (first
    positional argument after `model_output`, or `timestep=` keyword) only triggers a deprecation notice.

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. May also be passed as the second
            legacy positional argument.

    Returns:
        `torch.Tensor`:
            The converted model output: a data (x0) prediction for `dpmsolver++` / `sde-dpmsolver++`, a noise
            prediction for `dpmsolver`.

    Raises:
        ValueError: if `sample` is missing, or if `config.prediction_type` / `config.algorithm_type` is not
            supported.
    """
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    if sample is None:
        if len(args) > 1:
            sample = args[1]
        else:
            raise ValueError("missing `sample` as a required keyward argument")
    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # DPM-Solver++ needs to solve an integral of the data prediction model.
    if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output.
            if self.config.variance_type in ["learned", "learned_range"]:
                model_output = model_output[:, :3]
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = alpha_t * sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverSinglestepScheduler."
            )

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)

        return x0_pred

    # DPM-Solver needs to solve an integral of the noise prediction model.
    elif self.config.algorithm_type == "dpmsolver":
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output.
            if self.config.variance_type in ["learned", "learned_range"]:
                epsilon = model_output[:, :3]
            else:
                epsilon = model_output
        elif self.config.prediction_type == "sample":
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            epsilon = (sample - alpha_t * model_output) / sigma_t
        elif self.config.prediction_type == "v_prediction":
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            epsilon = alpha_t * model_output + sigma_t * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverSinglestepScheduler."
            )

        if self.config.thresholding:
            # Thresholding acts on the x0 prediction: convert epsilon -> x0, threshold, convert back.
            # The noise level comes from the step counter like in every other branch; the deprecated
            # `timestep` argument is usually `None` and must not be used as an index.
            sigma = self.sigmas[self.step_index]
            alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
            x0_pred = (sample - sigma_t * epsilon) / alpha_t
            x0_pred = self._threshold_sample(x0_pred)
            epsilon = (sample - alpha_t * x0_pred) / sigma_t

        return epsilon

    # Fail loudly instead of implicitly returning `None` for an unknown algorithm.
    raise ValueError(
        f"algorithm_type given as {self.config.algorithm_type} must be one of `dpmsolver`, `dpmsolver++`, or"
        " `sde-dpmsolver++` for the DPMSolverSinglestepScheduler."
    )
def singlestep_dpm_solver_second_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the second-order singlestep DPMSolver.

    The noise levels are read from `self.sigmas` at `step_index + 1` (target), `step_index` and `step_index - 1`.
    The update scales `sample` by the ratio between the target level and the `step_index - 1` level, so `sample`
    is expected to be the state that belongs to that earlier step (NOTE(review): confirm against the caller in
    `step`).

    Args:
        model_output_list (`List[torch.Tensor]`):
            The converted model outputs; the last two entries are used (`[-1]` is the most recent one).
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. May also be passed as the third
            legacy positional argument.
        noise (`torch.Tensor`, *optional*):
            Noise for the stochastic update; required when `config.algorithm_type` is `sde-dpmsolver++`.

    The legacy `timestep_list` and `prev_timestep` arguments are accepted for backward compatibility only; they
    trigger a deprecation notice and are otherwise ignored.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.

    Raises:
        ValueError: if `sample` is missing, if `noise` is missing for `sde-dpmsolver++`, or if the configured
            `algorithm_type` / `solver_type` is not supported by this update.
    """
    timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 2:
            sample = args[2]
        else:
            raise ValueError(" missing `sample` as a required keyward argument")
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    algorithm_type = self.config.algorithm_type
    solver_type = self.config.solver_type
    # Validate explicitly: an `assert` disappears under `python -O`, and an unsupported combination would
    # otherwise only surface as an UnboundLocalError on `x_t`.
    if algorithm_type not in ("dpmsolver++", "dpmsolver", "sde-dpmsolver++"):
        raise ValueError(f"algorithm_type {algorithm_type} is not supported by the second-order update")
    if solver_type not in ("midpoint", "heun"):
        raise ValueError(f"solver_type {solver_type} is not supported by the second-order update")
    if algorithm_type == "sde-dpmsolver++" and noise is None:
        raise ValueError("`noise` is required when `algorithm_type` is `sde-dpmsolver++`")

    sigma_t, sigma_s0, sigma_s1 = (
        self.sigmas[self.step_index + 1],
        self.sigmas[self.step_index],
        self.sigmas[self.step_index - 1],
    )

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)

    # half log-SNR at the three points
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

    m0, m1 = model_output_list[-1], model_output_list[-2]

    # h: full step in lambda; r0: relative position of the intermediate point inside that step
    h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1
    r0 = h_0 / h
    # D0: model output at the start of the step; D1: finite-difference estimate of its first derivative
    D0, D1 = m1, (1.0 / r0) * (m0 - m1)
    if algorithm_type == "dpmsolver++":
        # See https://arxiv.org/abs/2211.01095 for detailed derivations
        if solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s1) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
            )
        else:  # heun
            x_t = (
                (sigma_t / sigma_s1) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            )
    elif algorithm_type == "dpmsolver":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        if solver_type == "midpoint":
            x_t = (
                (alpha_t / alpha_s1) * sample
                - (sigma_t * (torch.exp(h) - 1.0)) * D0
                - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
            )
        else:  # heun
            x_t = (
                (alpha_t / alpha_s1) * sample
                - (sigma_t * (torch.exp(h) - 1.0)) * D0
                - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
            )
    else:  # sde-dpmsolver++
        if solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s1 * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        else:  # heun
            x_t = (
                (sigma_t / sigma_s1 * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
    return x_t
""" timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) if sample is None: if len(args) > 2: sample = args[2] else: raise ValueError(" missing`sample` as a required keyward argument") if timestep_list is not None: deprecate( "timestep_list", "1.0.0", "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) if prev_timestep is not None: deprecate( "prev_timestep", "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( self.sigmas[self.step_index + 1], self.sigmas[self.step_index], self.sigmas[self.step_index - 1], self.sigmas[self.step_index - 2], ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m2 D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2) D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1) D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s2) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1_1 ) elif self.config.solver_type == 
"heun": x_t = ( (sigma_t / sigma_s2) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (alpha_t / alpha_s2) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1 ) elif self.config.solver_type == "heun": x_t = ( (alpha_t / alpha_s2) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 ) return x_t def singlestep_dpm_solver_update( self, model_output_list: List[torch.Tensor], *args, sample: torch.Tensor = None, order: int = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ One step for the singlestep DPMSolver. Args: model_output_list (`List[torch.Tensor]`): The direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): The current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): The previous discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by diffusion process. order (`int`): The solver order at this step. Returns: `torch.Tensor`: The sample tensor at the previous timestep. 
""" timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) if sample is None: if len(args) > 2: sample = args[2] else: raise ValueError(" missing`sample` as a required keyward argument") if order is None: if len(args) > 3: order = args[3] else: raise ValueError(" missing `order` as a required keyward argument") if timestep_list is not None: deprecate( "timestep_list", "1.0.0", "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) if prev_timestep is not None: deprecate( "prev_timestep", "1.0.0", "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) if order == 1: return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample, noise=noise) elif order == 2: return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise) elif order == 3: return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample) else: raise ValueError(f"Order must be 1, 2, 3, got {order}") # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps index_candidates = (schedule_timesteps == timestep).nonzero() if len(index_candidates) == 0: step_index = len(self.timesteps) - 1 # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. 
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    generator=None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance the sample by one scheduler step using the singlestep DPMSolver.

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        timestep (`int`):
            The current discrete timestep; only used to initialise the step index on the first call.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            Passed to `randn_tensor` when `algorithm_type` is `"sde-dpmsolver++"`.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        `ValueError`: If `set_timesteps` has not been run (`num_inference_steps` is `None`).
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    converted = self.convert_model_output(model_output, sample=sample)

    # Shift the stored outputs one slot towards the front and append the newest one.
    history = self.model_outputs
    for slot in range(1, self.config.solver_order):
        history[slot - 1] = history[slot]
    history[-1] = converted

    noise = None
    if self.config.algorithm_type == "sde-dpmsolver++":
        noise = randn_tensor(converted.shape, generator=generator, device=converted.device, dtype=converted.dtype)

    # For img2img, denoising might start at a step whose scheduled order is > 1, which is not possible
    # without earlier outputs: lower the order until the oldest output it needs is available.
    order = self.order_list[self.step_index]
    while history[-order] is None:
        order -= 1

    # For single-step solvers, the initial value is refreshed every time an order-1 step is taken.
    if order == 1:
        self.sample = sample

    prev_sample = self.singlestep_dpm_solver_update(history, sample=self.sample, order=order, noise=noise)

    # upon completion increase step index by one
    self._step_index += 1

    return SchedulerOutput(prev_sample=prev_sample) if return_dict else (prev_sample,)
""" return sample # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise def add_noise( self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] elif self.step_index is not None: # add_noise is called after first denoising step (for inpainting) step_indices = [self.step_index] * timesteps.shape[0] else: # add noise is called before first denoising step to create initial latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py ================================================ # Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm import math from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils.torch_utils import randn_tensor from .scheduling_utils import SchedulerMixin, SchedulerOutput class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ Implements DPMSolverMultistepScheduler in EDM formulation as presented in Karras et al. 2022 [1]. `EDMDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364 This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. Args: sigma_min (`float`, *optional*, defaults to 0.002): Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable range is [0, 10]. sigma_max (`float`, *optional*, defaults to 80.0): Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable range is [0.2, 80.0]. sigma_data (`float`, *optional*, defaults to 0.5): The standard deviation of the data distribution. 
This is set to 0.5 in the EDM paper [1]. sigma_schedule (`str`, *optional*, defaults to `karras`): Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper (https://arxiv.org/abs/2206.00364). The other acceptable value is "exponential". The exponential schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl. num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. solver_order (`int`, defaults to 2): The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. dynamic_thresholding_ratio (`float`, defaults to 0.995): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, defaults to `dpmsolver++`): Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. solver_type (`str`, defaults to `midpoint`): Solver type for the second-order solver; can be `midpoint` or `heun`.
@register_to_config
def __init__(
    self,
    sigma_min: float = 0.002,
    sigma_max: float = 80.0,
    sigma_data: float = 0.5,
    sigma_schedule: str = "karras",
    num_train_timesteps: int = 1000,
    prediction_type: str = "epsilon",
    rho: float = 7.0,
    solver_order: int = 2,
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    algorithm_type: str = "dpmsolver++",
    solver_type: str = "midpoint",
    lower_order_final: bool = True,
    euler_at_final: bool = False,
    final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
):
    """
    Validate the solver settings and build the default (training-length) sigma and timestep tables.

    Raises:
        `NotImplementedError`:
            If `algorithm_type` or `solver_type` is neither supported nor one of the values that are remapped
            to a supported default (`"deis"`; `"logrho"`, `"bh1"`, `"bh2"`).
        `ValueError`:
            If `sigma_schedule` is not `"karras"` or `"exponential"`, or if `final_sigmas_type="zero"` is
            combined with an algorithm type that does not support it.
    """
    # settings for DPM-Solver
    if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"]:
        if algorithm_type == "deis":
            # Rebind the local name as well as the config entry: the checks below must see the effective
            # algorithm, otherwise the remapped value is rejected again by the `final_sigmas_type` check.
            algorithm_type = "dpmsolver++"
            self.register_to_config(algorithm_type=algorithm_type)
        else:
            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")

    if solver_type not in ["midpoint", "heun"]:
        if solver_type in ["logrho", "bh1", "bh2"]:
            self.register_to_config(solver_type="midpoint")
        else:
            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

    # Cannot trigger for the algorithm types accepted above; kept as a guard for future additions.
    if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
        raise ValueError(
            f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
        )

    ramp = torch.linspace(0, 1, num_train_timesteps)
    if sigma_schedule == "karras":
        sigmas = self._compute_karras_sigmas(ramp)
    elif sigma_schedule == "exponential":
        sigmas = self._compute_exponential_sigmas(ramp)
    else:
        # Without this branch `sigmas` stays unbound and the next line fails with an UnboundLocalError.
        raise ValueError(
            f"`sigma_schedule` must be one of 'karras' or 'exponential', but got {sigma_schedule}"
        )

    self.timesteps = self.precondition_noise(sigmas)

    # One trailing zero sigma is appended after the schedule.
    self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

    # setable values
    self.num_inference_steps = None
    self.model_outputs = [None] * solver_order
    self.lower_order_nums = 0
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
""" self._begin_index = begin_index # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs def precondition_inputs(self, sample, sigma): c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) scaled_sample = sample * c_in return scaled_sample # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_noise def precondition_noise(self, sigma): if not isinstance(sigma, torch.Tensor): sigma = torch.tensor([sigma]) c_noise = 0.25 * torch.log(sigma) return c_noise # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs def precondition_outputs(self, sample, model_output, sigma): sigma_data = self.config.sigma_data c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) if self.config.prediction_type == "epsilon": c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 elif self.config.prediction_type == "v_prediction": c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 else: raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.") denoised = c_skip * sample + c_out * model_output return denoised # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. Args: sample (`torch.Tensor`): The input sample. timestep (`int`, *optional*): The current timestep in the diffusion chain. Returns: `torch.Tensor`: A scaled input sample. 
""" if self.step_index is None: self._init_step_index(timestep) sigma = self.sigmas[self.step_index] sample = self.precondition_inputs(sample, sigma) self.is_scale_input_called = True return sample def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ self.num_inference_steps = num_inference_steps ramp = torch.linspace(0, 1, self.num_inference_steps) if self.config.sigma_schedule == "karras": sigmas = self._compute_karras_sigmas(ramp) elif self.config.sigma_schedule == "exponential": sigmas = self._compute_exponential_sigmas(ramp) sigmas = sigmas.to(dtype=torch.float32, device=device) self.timesteps = self.precondition_noise(sigmas) if self.config.final_sigmas_type == "sigma_min": sigma_last = self.config.sigma_min elif self.config.final_sigmas_type == "zero": sigma_last = 0 else: raise ValueError( f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" ) self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)]) self.model_outputs = [ None, ] * self.config.solver_order self.lower_order_nums = 0 # add an index counter for schedulers that allow duplicated timesteps self._step_index = None self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: """Constructs the noise schedule of Karras et al. 
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
    """
    Build a sigma schedule that is evenly spaced in log-sigma, ordered from `sigma_max` down to `sigma_min`.

    Implementation closely follows k-diffusion.
    https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26

    Args:
        ramp: Only its length is used; it sets the number of sigmas produced.
        sigma_min: Smallest sigma. A falsy value (including `None`) selects `config.sigma_min`.
        sigma_max: Largest sigma. A falsy value (including `None`) selects `config.sigma_max`.

    Returns:
        `torch.Tensor`: A 1-D tensor of `len(ramp)` sigmas in descending order.
    """
    if not sigma_min:
        sigma_min = self.config.sigma_min
    if not sigma_max:
        sigma_max = self.config.sigma_max

    log_low = math.log(sigma_min)
    log_high = math.log(sigma_max)
    ascending = torch.linspace(log_low, log_high, len(ramp)).exp()
    return ascending.flip(0)
https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) abs_sample = sample.abs() # "a certain percentile absolute pixel value" s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) s = torch.clamp( s, min=1, max=self.config.sample_max_value ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" sample = sample.reshape(batch_size, channels, *remaining_dims) sample = sample.to(dtype) return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) # get distribution dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low = log_sigmas[low_idx] high = log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) w = np.clip(w, 0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t def _sigma_to_alpha_sigma_t(self, sigma): alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1 sigma_t = sigma return alpha_t, sigma_t def convert_model_output( self, model_output: torch.Tensor, sample: torch.Tensor = None, ) -> torch.Tensor: """ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. 
def dpm_solver_first_order_update(
    self,
    model_output: torch.Tensor,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    One step for the first-order DPMSolver (equivalent to DDIM).

    The step goes from the sigma at `self.step_index` to the sigma at `self.step_index + 1`.

    Args:
        model_output (`torch.Tensor`):
            The converted model output used by the update formula.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        noise (`torch.Tensor`, *optional*):
            Required (checked with `assert`) when `config.algorithm_type` is `"sde-dpmsolver++"`; unused
            otherwise.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    current = self.step_index
    sigma_t, sigma_s = self.sigmas[current + 1], self.sigmas[current]
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)

    # Step size between the two noise levels, measured in lambda = log(alpha) - log(sigma).
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
    h = lambda_t - lambda_s

    algorithm = self.config.algorithm_type
    if algorithm == "dpmsolver++":
        decay = torch.exp(-h)
        x_t = (sigma_t / sigma_s) * sample - (alpha_t * (decay - 1.0)) * model_output
    elif algorithm == "sde-dpmsolver++":
        assert noise is not None
        sample_term = (sigma_t / sigma_s * torch.exp(-h)) * sample
        data_term = (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
        noise_term = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        x_t = sample_term + data_term + noise_term
    return x_t
""" sigma_t, sigma_s0, sigma_s1 = ( self.sigmas[self.step_index + 1], self.sigmas[self.step_index], self.sigmas[self.step_index - 1], ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) m0, m1 = model_output_list[-1], model_output_list[-2] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) elif self.config.algorithm_type == "sde-dpmsolver++": assert noise is not None if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) return x_t def multistep_dpm_solver_third_order_update( self, model_output_list: List[torch.Tensor], sample: torch.Tensor = None, ) -> torch.Tensor: """ One step for the third-order multistep DPMSolver. 
Args: model_output_list (`List[torch.Tensor]`): The direct outputs from learned diffusion model at current and latter timesteps. sample (`torch.Tensor`): A current instance of a sample created by diffusion process. Returns: `torch.Tensor`: The sample tensor at the previous timestep. """ sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( self.sigmas[self.step_index + 1], self.sigmas[self.step_index], self.sigmas[self.step_index - 1], self.sigmas[self.step_index - 2], ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m0 D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2206.00927 for detailed derivations x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps index_candidates = (schedule_timesteps == timestep).nonzero() if len(index_candidates) == 0: step_index = len(self.timesteps) - 1 # The sigma index that is taken for the **very** 
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    generator=None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance the sample by one scheduler step using the multistep DPMSolver.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep; only used to initialise the step index on the first call.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator, used when `algorithm_type` is `"sde-dpmsolver++"`.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        `ValueError`: If `set_timesteps` has not been run (`num_inference_steps` is `None`).
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    # Improve numerical stability for small number of steps
    total_steps = len(self.timesteps)
    short_schedule = self.config.lower_order_final and total_steps < 15
    lower_order_final = (self.step_index == total_steps - 1) and (
        self.config.euler_at_final or short_schedule or self.config.final_sigmas_type == "zero"
    )
    lower_order_second = (self.step_index == total_steps - 2) and short_schedule

    converted = self.convert_model_output(model_output, sample=sample)

    # Shift the stored outputs one slot towards the front and append the newest one.
    history = self.model_outputs
    for slot in range(1, self.config.solver_order):
        history[slot - 1] = history[slot]
    history[-1] = converted

    noise = None
    if self.config.algorithm_type == "sde-dpmsolver++":
        noise = randn_tensor(converted.shape, generator=generator, device=converted.device, dtype=converted.dtype)

    # Choose the update order: limited by the configured order, by how many steps have been taken so far
    # (`lower_order_nums`), and by the end-of-schedule fallbacks computed above.
    solver_order = self.config.solver_order
    if solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
        prev_sample = self.dpm_solver_first_order_update(converted, sample=sample, noise=noise)
    elif solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
        prev_sample = self.multistep_dpm_solver_second_order_update(history, sample=sample, noise=noise)
    else:
        prev_sample = self.multistep_dpm_solver_third_order_update(history, sample=sample)

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    # upon completion increase step index by one
    self._step_index += 1

    return SchedulerOutput(prev_sample=prev_sample) if return_dict else (prev_sample,)
def __len__(self):
    """Return `config.num_train_timesteps`."""
    train_steps = self.config.num_train_timesteps
    return train_steps
class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
    """
    Implements the Euler scheduler in EDM formulation as presented in Karras et al. 2022 [1].

    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
    https://arxiv.org/abs/2206.00364

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        sigma_min (`float`, *optional*, defaults to 0.002):
            Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
            range is [0, 10].
        sigma_max (`float`, *optional*, defaults to 80.0):
            Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
            range is [0.2, 80.0].
        sigma_data (`float`, *optional*, defaults to 0.5):
            The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
        sigma_schedule (`str`, *optional*, defaults to `karras`):
            Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
            (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
            incorporated in this model: https://huggingface.co/stabilityai/cosxl.
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function. `precondition_outputs` handles `epsilon` (predicts the noise of
            the diffusion process) and `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper); any other value makes it raise a
            `ValueError`.
        rho (`float`, *optional*, defaults to 7.0):
            The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        sigma_data: float = 0.5,
        sigma_schedule: str = "karras",
        num_train_timesteps: int = 1000,
        prediction_type: str = "epsilon",
        rho: float = 7.0,
    ):
        if sigma_schedule not in ["karras", "exponential"]:
            raise ValueError(
                f"Wrong value provided for `{sigma_schedule=}`. Expected one of 'karras' or 'exponential'."
            )

        # setable values
        self.num_inference_steps = None

        # Until `set_timesteps` is called the schedule spans all training timesteps.
        ramp = torch.linspace(0, 1, num_train_timesteps)
        if sigma_schedule == "karras":
            sigmas = self._compute_karras_sigmas(ramp)
        elif sigma_schedule == "exponential":
            sigmas = self._compute_exponential_sigmas(ramp)

        self.timesteps = self.precondition_noise(sigmas)

        # A trailing zero sigma lets `step` read `sigmas[step_index + 1]` on the last step.
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self.is_scale_input_called = False

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        return (self.config.sigma_max**2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def precondition_inputs(self, sample, sigma):
        """Scale `sample` by `c_in = 1 / sqrt(sigma**2 + sigma_data**2)`."""
        c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
        scaled_sample = sample * c_in
        return scaled_sample

    def precondition_noise(self, sigma):
        """Map `sigma` to the timestep value `c_noise = 0.25 * log(sigma)`; non-tensor input is wrapped in a tensor."""
        if not isinstance(sigma, torch.Tensor):
            sigma = torch.tensor([sigma])

        c_noise = 0.25 * torch.log(sigma)

        return c_noise

    def precondition_outputs(self, sample, model_output, sigma):
        """
        Combine `sample` and `model_output` into the denoised estimate `c_skip * sample + c_out * model_output`.

        Raises:
            `ValueError`: if `prediction_type` is neither `epsilon` nor `v_prediction`.
        """
        sigma_data = self.config.sigma_data
        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)

        if self.config.prediction_type == "epsilon":
            c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        elif self.config.prediction_type == "v_prediction":
            c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        else:
            raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")

        denoised = c_skip * sample + c_out * model_output

        return denoised

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Multiplies the input by `1 / (sigma**2 + sigma_data**2) ** 0.5` (see
        `precondition_inputs`), where `sigma` is the sigma at the current step index.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`float` or `torch.Tensor`):
                The current timestep in the diffusion chain. Only used to initialise the step index when it is not
                set yet.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = self.precondition_inputs(sample, sigma)

        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        ramp = torch.linspace(0, 1, self.num_inference_steps)
        if self.config.sigma_schedule == "karras":
            sigmas = self._compute_karras_sigmas(ramp)
        elif self.config.sigma_schedule == "exponential":
            sigmas = self._compute_exponential_sigmas(ramp)

        sigmas = sigmas.to(dtype=torch.float32, device=device)
        self.timesteps = self.precondition_noise(sigmas)

        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        # Reset the counters so the next `step`/`scale_model_input` call re-derives the index.
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
    def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022).

        `sigma_min` / `sigma_max` fall back to the config values only when they are `None`, so an explicit `0.0` is
        honoured instead of being silently replaced.
        """
        sigma_min = self.config.sigma_min if sigma_min is None else sigma_min
        sigma_max = self.config.sigma_max if sigma_max is None else sigma_max

        rho = self.config.rho
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
        """Implementation closely follows k-diffusion.

        https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26

        Only `len(ramp)` is used: the sigmas are log-spaced between `sigma_max` and `sigma_min` in descending order.
        `sigma_min` / `sigma_max` fall back to the config values only when they are `None`.
        """
        sigma_min = self.config.sigma_min if sigma_min is None else sigma_min
        sigma_max = self.config.sigma_max if sigma_max is None else sigma_max
        sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the index in `schedule_timesteps` (default: `self.timesteps`) whose value equals `timestep`."""
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        """Initialise `_step_index` from `begin_index` when set, otherwise by looking `timestep` up in the schedule."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EDMEulerSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`, defaults to 0.0):
                Controls the amount of extra noise injected per step: `gamma` is `s_churn / (len(sigmas) - 1)`,
                capped at `2**0.5 - 1`. With the default of 0 no noise is injected.
            s_tmin (`float`, defaults to 0.0):
                Lower bound of the sigma range in which `gamma` is non-zero.
            s_tmax (`float`, defaults to `inf`):
                Upper bound of the sigma range in which `gamma` is non-zero.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator. It is only consumed on steps where `gamma > 0`.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            `ValueError`: if `timestep` is an `int` or an integer tensor.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EDMEulerScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            # Noise is only needed when churn is active; skipping the draw otherwise avoids allocating a
            # sample-sized tensor on every deterministic step.
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
            eps = noise * s_noise
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

        dt = self.sigmas[self.step_index + 1] - sigma_hat

        prev_sample = sample + derivative * dt

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Return `original_samples + noise * sigma`, with one sigma per batch entry broadcast over trailing dims."""
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
""" prev_sample: torch.Tensor pred_original_sample: Optional[torch.Tensor] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. 
class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Ancestral sampling with Euler method steps.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear` or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
            or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
            paper). `sample` is recognised but `step` raises `NotImplementedError` for it.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families. Only applied with
            `timestep_spacing="leading"`.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), stored from highest to lowest noise with a trailing 0
        # so that `step` can always read `sigmas[step_index + 1]`.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.is_scale_input_called = False

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Divides the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`float` or `torch.Tensor`):
                The current timestep in the diffusion chain. Only used to initialise the step index when it is not
                set yet.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. The
                sigmas always stay on the CPU.

        Raises:
            `ValueError`: if `timestep_spacing` in the config is not one of `linspace`, `leading` or `trailing`.
        """
        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates whole-number timesteps by multiplying by the integer ratio; they are rounded and then
            # stored as float32
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # walks down from num_train_timesteps in (possibly fractional) strides; the values are rounded and
            # then stored as float32
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Interpolate the training sigmas at the (possibly fractional) inference timesteps.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        # Sigmas are kept on the CPU to avoid too much CPU/GPU communication, so they are never moved to
        # `device` in the first place.
        self.sigmas = torch.from_numpy(sigmas)

        self.timesteps = torch.from_numpy(timesteps).to(device=device)
        self._step_index = None
        self._begin_index = None

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the index in `schedule_timesteps` (default: `self.timesteps`) whose value equals `timestep`."""
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        """Initialise `_step_index` from `begin_index` when set, otherwise by looking `timestep` up in the schedule."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            `ValueError`: if `timestep` is an `int` or an integer tensor, or if `prediction_type` is unknown.
            `NotImplementedError`: if `prediction_type` is `sample`.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerAncestralDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # Split the move from sigma_from to sigma_to into a deterministic part (down to sigma_down) and a
        # stochastic part (fresh noise with standard deviation sigma_up).
        sigma_from = self.sigmas[self.step_index]
        sigma_to = self.sigmas[self.step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        dt = sigma_down - sigma

        prev_sample = sample + derivative * dt

        device = model_output.device
        noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)

        prev_sample = prev_sample + noise * sigma_up

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Return `original_samples + noise * sigma`, with one sigma per batch entry broadcast over trailing dims."""
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Build a beta schedule by discretizing a cumulative-alpha function `alpha_bar(t)` over `t` in [0, 1].

    For every step `i` out of `N = num_diffusion_timesteps`, the beta is
    `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`):
            The number of betas to produce.
        max_beta (`float`):
            Upper bound applied to every beta; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`):
            The `alpha_bar` function to discretize. Choose from `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): float32 tensor holding `num_diffusion_timesteps` betas.

    Raises:
        ValueError: If `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type not in ("cosine", "exp"):
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    else:

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    total = num_diffusion_timesteps
    capped = [
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta)
        for step in range(total)
    ]
    return torch.tensor(capped, dtype=torch.float32)
Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. beta_start (`float`, defaults to 0.0001): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. beta_schedule (`str`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear` or `scaled_linear`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). interpolation_type (`str`, defaults to `"linear"`, *optional*): The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of `"linear"` or `"log_linear"`. use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness.
Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). final_sigmas_type (`str`, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", interpolation_type: str = "linear", use_karras_sigmas: Optional[bool] = False, sigma_min: Optional[float] = None, sigma_max: Optional[float] = None, timestep_spacing: str = "linspace", timestep_type: str = "discrete", # can be "discrete" or "continuous" steps_offset: int = 0, rescale_betas_zero_snr: bool = False, final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) if rescale_betas_zero_snr: # Close to 0 without being 0 so first sigma is not inf # FP16 smallest positive subnormal works well here self.alphas_cumprod[-1] = 2**-24 sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0) timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) # setable values self.num_inference_steps = None # TODO: Support the full EDM scalings for all prediction types and timestep types if timestep_type == "continuous" and prediction_type == "v_prediction": self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]) else: self.timesteps = timesteps self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) self.is_scale_input_called = False self.use_karras_sigmas = use_karras_sigmas self._step_index = None self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property def init_noise_sigma(self): # standard deviation of the initial noise distribution max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max() if self.config.timestep_spacing in ["linspace", "trailing"]: return max_sigma return (max_sigma**2 + 1) ** 0.5 @property def step_index(self): """ The index counter for current timestep. It will increase 1 after each scheduler step. """ return self._step_index @property def begin_index(self): """ The index for the first timestep. 
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """
    Divide the model input by `(sigma**2 + 1) ** 0.5` for the current step, to match the Euler algorithm.

    This keeps the scheduler interchangeable with schedulers that need to scale the denoising model input
    depending on the current timestep. If no step index has been set yet, it is initialized from `timestep`.
    The call is recorded in `self.is_scale_input_called`.

    Args:
        sample (`torch.Tensor`):
            The input sample.
        timestep (`float` or `torch.Tensor`):
            The current timestep in the diffusion chain.

    Returns:
        `torch.Tensor`:
            A scaled input sample.
    """
    if self.step_index is None:
        self._init_step_index(timestep)

    current_sigma = self.sigmas[self.step_index]
    scaled = sample / ((current_sigma**2 + 1) ** 0.5)

    # Only flag the call once the scaling itself has succeeded.
    self.is_scale_input_called = True
    return scaled
sigmas (`List[float]`, *optional*): Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas will be generated based on the relevant scheduler attributes. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the custom sigmas schedule. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` should be set.") if num_inference_steps is None and timesteps is None and sigmas is None: raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.") if num_inference_steps is not None and (timesteps is not None or sigmas is not None): raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.") if timesteps is not None and self.config.use_karras_sigmas: raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.") if ( timesteps is not None and self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction" ): raise ValueError( "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`." ) if num_inference_steps is None: num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1 self.num_inference_steps = num_inference_steps if sigmas is not None: log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)) sigmas = np.array(sigmas).astype(np.float32) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]]) else: if timesteps is not None: timesteps = np.array(timesteps).astype(np.float32) else: # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = np.linspace( 0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32 )[::-1].copy() elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = ( (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) ) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = ( (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) ) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) if self.config.interpolation_type == "linear": sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) elif self.config.interpolation_type == "log_linear": sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy() else: raise ValueError( f"{self.config.interpolation_type} is not implemented. 
Please specify interpolation_type to either" " 'linear' or 'log_linear'" ) if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) if self.config.final_sigmas_type == "sigma_min": sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 elif self.config.final_sigmas_type == "zero": sigma_last = 0 else: raise ValueError( f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) # TODO: Support the full EDM scalings for all prediction types and timestep types if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction": self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device) else: self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) self._step_index = None self._begin_index = None self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication def _sigma_to_t(self, sigma, log_sigmas): # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) # get distribution dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low = log_sigmas[low_idx] high = log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) w = np.clip(w, 0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: 
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Return the position of `timestep` inside a timestep schedule.

    Args:
        timestep (`float` or `torch.Tensor`):
            The timestep value to look up.
        schedule_timesteps (`torch.Tensor`, *optional*):
            The schedule to search. Defaults to `self.timesteps`.

    Returns:
        `int`: The index of the match. When the value occurs more than once, the second occurrence is
        returned (see the comment in the body).

    Raises:
        IndexError: If `timestep` does not occur in the schedule.
    """
    if schedule_timesteps is None:
        schedule_timesteps = self.timesteps

    indices = (schedule_timesteps == timestep).nonzero()

    # Without this guard the lookup below fails with torch's generic
    # "index 0 is out of bounds" message, which hides the actual cause.
    if len(indices) == 0:
        raise IndexError(
            f"timestep {timestep} is not part of the scheduler's timesteps. Make sure to pass one of the"
            " `scheduler.timesteps` as a timestep."
        )

    # The sigma index that is taken for the **very** first `step`
    # is always the second index (or the last index if there is only 1)
    # This way we can ensure we don't accidentally skip a sigma in
    # case we start in the middle of the denoising schedule (e.g. for image-to-image)
    pos = 1 if len(indices) > 1 else 0

    return indices[pos].item()
This function propagates the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. timestep (`float`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. s_churn (`float`): s_tmin (`float`): s_tmax (`float`): s_noise (`float`, defaults to 1.0): Scaling factor for noise added to the sample. generator (`torch.Generator`, *optional*): A random number generator. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or tuple. Returns: [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" " one of the `scheduler.timesteps` as a timestep." ), ) if not self.is_scale_input_called: logger.warning( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." 
) if self.step_index is None: self._init_step_index(timestep) # Upcast to avoid precision issues when computing prev_sample sample = sample.to(torch.float32) sigma = self.sigmas[self.step_index] gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 noise = randn_tensor( model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator ) eps = noise * s_noise sigma_hat = sigma * (gamma + 1) if gamma > 0: sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise # NOTE: "original_sample" should not be an expected prediction_type but is left in for # backwards compatibility if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample": pred_original_sample = model_output elif self.config.prediction_type == "epsilon": pred_original_sample = sample - sigma_hat * model_output elif self.config.prediction_type == "v_prediction": # denoised = model_output * c_out + input * c_skip pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" ) # 2. 
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
    """
    Compute `sqrt(alpha_prod) * noise - sqrt(1 - alpha_prod) * sample` for the given timesteps.

    Each entry of `timesteps` is located in `self.timesteps` via `index_for_timestep`, and the resulting
    positions are used to index `self.alphas_cumprod`.

    Args:
        sample (`torch.Tensor`):
            The sample the velocity is computed for; its device and dtype are used for the computation.
        noise (`torch.Tensor`):
            The noise paired with `sample`.
        timesteps (`torch.Tensor`):
            Floating-point timestep values taken from `scheduler.timesteps`, one per batch entry.

    Returns:
        `torch.Tensor`: The velocity.

    Raises:
        ValueError: If `timesteps` is an `int`, a `torch.IntTensor` or a `torch.LongTensor`.
    """
    if isinstance(timesteps, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
            " `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass"
            " one of the `scheduler.timesteps` as a timestep."
        )

    target_device = sample.device
    # mps does not support float64, so floating-point timesteps are forced to float32 there
    needs_float32 = target_device.type == "mps" and torch.is_floating_point(timesteps)
    cast_kwargs = {"dtype": torch.float32} if needs_float32 else {}
    schedule_timesteps = self.timesteps.to(target_device, **cast_kwargs)
    timesteps = timesteps.to(target_device, **cast_kwargs)

    step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]

    # NOTE(review): these are positions within `self.timesteps`, yet they index `alphas_cumprod`
    # directly - confirm both sequences are meant to share the same ordering.
    alphas_cumprod = self.alphas_cumprod.to(sample)
    selected = alphas_cumprod[step_indices]

    def _broadcastable(coeff):
        # Flatten, then append trailing singleton axes until the rank matches `sample`.
        coeff = coeff.flatten()
        while len(coeff.shape) < len(sample.shape):
            coeff = coeff.unsqueeze(-1)
        return coeff

    sqrt_alpha_prod = _broadcastable(selected**0.5)
    sqrt_one_minus_alpha_prod = _broadcastable((1 - selected) ** 0.5)

    return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Optional, Tuple, Union import flax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left, ) @flax.struct.dataclass class EulerDiscreteSchedulerState: common: CommonSchedulerState # setable values init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray sigmas: jnp.ndarray num_inference_steps: Optional[int] = None @classmethod def create( cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray ): return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas) @dataclass class FlaxEulerDiscreteSchedulerOutput(FlaxSchedulerOutput): state: EulerDiscreteSchedulerState class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): """ Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original k-diffusion implementation by Katherine Crowson: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear` or `scaled_linear`. trained_betas (`jnp.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): the `dtype` used for params and computation.
""" _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] dtype: jnp.dtype @property def has_state(self): return True @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", dtype: jnp.dtype = jnp.float32, ): self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState: if common is None: common = CommonSchedulerState.create(self) timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5 sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas) sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)]) # standard deviation of the initial noise distribution if self.config.timestep_spacing in ["linspace", "trailing"]: init_noise_sigma = sigmas.max() else: init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5 return EulerDiscreteSchedulerState.create( common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas, ) def scale_model_input(self, state: EulerDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray: """ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. Args: state (`EulerDiscreteSchedulerState`): the `FlaxEulerDiscreteScheduler` state data class instance. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. timestep (`int`): current discrete timestep in the diffusion chain. 
def set_timesteps(
    self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> EulerDiscreteSchedulerState:
    """
    Build the inference schedule and return a copy of `state` that carries it. Supporting function to be
    run before inference.

    Args:
        state (`EulerDiscreteSchedulerState`):
            the `FlaxEulerDiscreteScheduler` state data class instance.
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        shape (`Tuple`):
            accepted but not used by this method.

    Returns:
        `EulerDiscreteSchedulerState`: the result of `state.replace(...)` with `timesteps`, `sigmas`,
        `num_inference_steps` and `init_noise_sigma` set.

    Raises:
        ValueError: if `config.timestep_spacing` is neither `linspace` nor `leading`.
    """
    spacing = self.config.timestep_spacing
    if spacing not in ("linspace", "leading"):
        raise ValueError(
            f"timestep_spacing must be one of ['linspace', 'leading'], got {self.config.timestep_spacing}"
        )

    if spacing == "linspace":
        timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
    else:
        step_ratio = self.config.num_train_timesteps // num_inference_steps
        ascending = (jnp.arange(0, num_inference_steps) * step_ratio).round()
        timesteps = ascending[::-1].copy().astype(float)
        timesteps += 1

    # Sigmas of the training schedule, interpolated at the chosen timesteps, with a terminal 0 appended.
    alphas_cumprod = state.common.alphas_cumprod
    train_sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
    interpolated = jnp.interp(timesteps, jnp.arange(0, len(train_sigmas)), train_sigmas)
    sigmas = jnp.concatenate([interpolated, jnp.array([0.0], dtype=self.dtype)])

    # standard deviation of the initial noise distribution
    peak = sigmas.max()
    if self.config.timestep_spacing in ["linspace", "trailing"]:
        init_noise_sigma = peak
    else:
        init_noise_sigma = (peak**2 + 1) ** 0.5

    return state.replace(
        timesteps=timesteps,
        sigmas=sigmas,
        num_inference_steps=num_inference_steps,
        init_noise_sigma=init_noise_sigma,
    )
reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: state (`EulerDiscreteSchedulerState`): the `FlaxEulerDiscreteScheduler` state data class instance. model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. order: coefficient for multi-step inference. return_dict (`bool`): option for returning tuple rather than FlaxEulerDiscreteScheduler class Returns: [`FlaxEulerDiscreteScheduler`] or `tuple`: [`FlaxEulerDiscreteScheduler`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if state.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) (step_index,) = jnp.where(state.timesteps == timestep, size=1) step_index = step_index[0] sigma = state.sigmas[step_index] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": pred_original_sample = sample - sigma * model_output elif self.config.prediction_type == "v_prediction": # * c_out + input * c_skip pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" ) # 2. 
Convert to an ODE derivative derivative = (sample - pred_original_sample) / sigma # dt = sigma_down - sigma dt = state.sigmas[step_index + 1] - sigma prev_sample = sample + derivative * dt if not return_dict: return (prev_sample, state) return FlaxEulerDiscreteSchedulerOutput(prev_sample=prev_sample, state=state) def add_noise( self, state: EulerDiscreteSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray, ) -> jnp.ndarray: sigma = state.sigmas[timesteps].flatten() sigma = broadcast_to_shape_from_left(sigma, noise.shape) noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py ================================================ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging from .scheduling_utils import SchedulerMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): """ Output class for the scheduler's `step` function output. 
Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule. Only applied when `use_dynamic_shifting` is `False`.
        use_dynamic_shifting (`bool`, defaults to `False`):
            When `True`, the static `shift` is not applied to the sigmas built by the constructor.
        base_shift (`float`, *optional*, defaults to 0.5):
        max_shift (`float`, *optional*, defaults to 1.15):
        base_image_seq_len (`int`, *optional*, defaults to 256):
        max_image_seq_len (`int`, *optional*, defaults to 4096):
            These four values are not read by the constructor body; they are only captured by
            `register_to_config`. NOTE(review): presumably consumed by pipelines to derive the dynamic shift —
            confirm against the callers.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: Optional[float] = 0.5,
        max_shift: Optional[float] = 1.15,
        base_image_seq_len: Optional[int] = 256,
        max_image_seq_len: Optional[int] = 4096,
    ):
        # descending training timesteps: num_train_timesteps, ..., 2, 1
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        # sigmas are the timesteps normalised into (0, 1]
        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        # timesteps are re-derived from the (possibly shifted) sigmas
        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep.
It will increase 1 after each scheduler step. """ return self._step_index @property def begin_index(self): """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index def set_begin_index(self, begin_index: int = 0): """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. Args: begin_index (`int`): The begin index for the scheduler. """ self._begin_index = begin_index def scale_noise( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ Forward process in flow-matching Args: sample (`torch.FloatTensor`): The input sample. timestep (`int`, *optional*): The current timestep in the diffusion chain. Returns: `torch.FloatTensor`: A scaled input sample. """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) if sample.device.type == "mps" and torch.is_floating_point(timestep): # mps does not support float64 schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) timestep = timestep.to(sample.device, dtype=torch.float32) else: schedule_timesteps = self.timesteps.to(sample.device) timestep = timestep.to(sample.device) # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] elif self.step_index is not None: # add_noise is called after first denoising step (for inpainting) step_indices = [self.step_index] * timestep.shape[0] else: # add noise is called before first denoising step to create initial latent(img2img) step_indices = [self.begin_index] * timestep.shape[0] sigma = 
sigmas[step_indices].flatten() while len(sigma.shape) < len(sample.shape): sigma = sigma.unsqueeze(-1) sample = sigma * noise + (1.0 - sigma) * sample return sample def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps def time_shift(self, mu: float, sigma: float, t: torch.Tensor): return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) def set_timesteps( self, num_inference_steps: int = None, device: Union[str, torch.device] = None, sigmas: Optional[List[float]] = None, mu: Optional[float] = None, ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ if self.config.use_dynamic_shifting and mu is None: raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") if sigmas is None: self.num_inference_steps = num_inference_steps timesteps = np.linspace( self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps ) sigmas = timesteps / self.config.num_train_timesteps if self.config.use_dynamic_shifting: sigmas = self.time_shift(mu, 1.0, sigmas) else: sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) timesteps = sigmas * self.config.num_train_timesteps self.timesteps = timesteps.to(device=device) self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) self._step_index = None self._begin_index = None def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps indices = (schedule_timesteps == timestep).nonzero() # The sigma index that is taken for the **very** first `step` # is always the 
second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) pos = 1 if len(indices) > 1 else 0 return indices[pos].item() def _init_step_index(self, timestep): if self.begin_index is None: if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) self._step_index = self.index_for_timestep(timestep) else: self._step_index = self._begin_index def step( self, model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor, s_churn: float = 0.0, s_tmin: float = 0.0, s_tmax: float = float("inf"), s_noise: float = 1.0, generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.FloatTensor`): The direct output from learned diffusion model. timestep (`float`): The current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. s_churn (`float`): s_tmin (`float`): s_tmax (`float`): s_noise (`float`, defaults to 1.0): Scaling factor for noise added to the sample. generator (`torch.Generator`, *optional*): A random number generator. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or tuple. Returns: [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. 
""" if ( isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor) ): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" " one of the `scheduler.timesteps` as a timestep." ), ) if self.step_index is None: self._init_step_index(timestep) # Upcast to avoid precision issues when computing prev_sample sample = sample.to(torch.float32) sigma = self.sigmas[self.step_index] sigma_next = self.sigmas[self.step_index + 1] prev_sample = sample + (sigma_next - sigma) * model_output # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) # upon completion increase step index by one self._step_index += 1 if not return_dict: return (prev_sample,) return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) def step_forward( self, model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor, s_churn: float = 0.0, s_tmin: float = 0.0, s_tmax: float = float("inf"), s_noise: float = 1.0, generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: """ Forward invrsion step """ if ( isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor) ): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" " one of the `scheduler.timesteps` as a timestep." 
), ) if self.step_index is None: self._init_step_index(timestep) # Upcast to avoid precision issues when computing prev_sample sample = sample.to(torch.float32) sigma = self.sigmas[self.step_index] sigma_next = self.sigmas[self.step_index + 1] next_sample = sample + (sigma - sigma_next) * model_output # Cast sample back to model compatible dtype next_sample = next_sample.to(model_output.dtype) # upon completion increase step index by one self._step_index -= 1 return (next_sample,) def __len__(self): return self.config.num_train_timesteps ================================================ FILE: src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py ================================================ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging from ..utils.torch_utils import randn_tensor from .scheduling_utils import SchedulerMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput): """ Output class for the scheduler's `step` function output. 
Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Heun scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
    """

    _compatibles = []
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
    ):
        # descending training timesteps: num_train_timesteps, ..., 2, 1
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        # sigmas are the normalised timesteps with the static shift applied
        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: returns `sigma * noise + (1.0 - sigma) * sample`.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain. Only used to initialise the step index when it has not
                been initialised yet.
            noise (`torch.FloatTensor`):
                The noise to blend in. It is multiplied by sigma unconditionally, so it must be provided despite the
                `None` default.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        # side effect: this initialises `_step_index`, so a later `step` call will not re-derive it
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        """Map a sigma to a timestep by multiplying with `config.num_train_timesteps`."""
        return sigma * self.config.num_train_timesteps

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        timesteps = np.linspace(
            self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
        )

        sigmas = timesteps / self.config.num_train_timesteps
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        # every timestep after the first is duplicated: one entry per model evaluation of the 2nd-order method
        timesteps = sigmas * self.config.num_train_timesteps
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
        self.timesteps = timesteps.to(device=device)

        # append the terminal 0.0, then duplicate all interior sigmas to line up with the duplicated timesteps
        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        # empty dt and derivative
        self.prev_derivative = None
        self.dt = None

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the position of `timestep` in `schedule_timesteps` (defaults to `self.timesteps`)."""
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise `_step_index` from `begin_index` when it is set, otherwise by looking up `timestep`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @property
    def state_in_first_order(self):
        """`True` while no first-order `dt` is stored, i.e. the next `step` call is a first-order one."""
        return self.dt is None

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Calls alternate between a first-order step (which stores `prev_derivative`, `dt` and `sample` on the
        scheduler) and a second-order correction that averages the stored derivative with the new one and then
        clears the stored state.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
                Numerator of the churn factor `gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)`.
            s_tmin  (`float`):
            s_tmax  (`float`):
                `gamma` is only non-zero when `s_tmin <= sigma <= s_tmax`.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchHeunDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchHeunDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchHeunDiscreteSchedulerOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `HeunDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / Heun's method
            sigma = self.sigmas[self.step_index - 1]
            sigma_next = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        # NOTE: the noise is drawn on every call, even when `gamma == 0` and it ends up unused
        noise = randn_tensor(
            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
        )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        if self.state_in_first_order:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma
            # 2. convert to an ODE derivative for 1st order
            derivative = (sample - denoised) / sigma_hat
            # 3. Delta timestep
            dt = sigma_next - sigma_hat

            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma_next
            # 2. 2nd order / Heun's method
            derivative = (sample - denoised) / sigma_next
            derivative = 0.5 * (self.prev_derivative + derivative)

            # 3. take prev timestep & sample
            dt = self.dt
            sample = self.sample

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None

        prev_sample = sample + derivative * dt
        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps


================================================
FILE: src/diffusers/schedulers/scheduling_heun_discrete.py
================================================
# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs, as a float32 tensor

    Raises:
        `ValueError`: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1), capped at max_beta
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Scheduler with Heun steps for discrete beta schedules.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.012):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, `squaredcos_cap_v2` or `exp`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        clip_sample (`bool`, defaults to `False`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,  # sensible defaults
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        use_karras_sigmas: Optional[bool] = False,
        clip_sample: Optional[bool] = False,
        clip_sample_range: float = 1.0,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
        elif beta_schedule == "exp":
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        #  set all values
        # (`set_timesteps` reads `self.config.use_karras_sigmas`, so it does not depend on the attribute set below)
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
        self.use_karras_sigmas = use_karras_sigmas

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the position of `timestep` in `schedule_timesteps` (defaults to `self.timesteps`)."""
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(
        self,
        sample: torch.Tensor,
        timestep: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample, `sample / sqrt(sigma**2 + 1)`.
        """
        # side effect: initialises `_step_index` on first use
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        num_train_timesteps: Optional[int] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            num_train_timesteps (`int`, *optional*):
                The number of diffusion steps used when training the model. If `None`, the default
                `num_train_timesteps` attribute is used.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be
                generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps`
                must be `None`, and `timestep_spacing` attribute will be ignored.

        Raises:
            `ValueError`: if both or neither of `num_inference_steps` and `timesteps` are passed, if `timesteps` is
            combined with `config.use_karras_sigmas`, or if `config.timestep_spacing` is unsupported.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
        if timesteps is not None and self.config.use_karras_sigmas:
            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")

        num_inference_steps = num_inference_steps or len(timesteps)
        self.num_inference_steps = num_inference_steps
        num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

        if timesteps is not None:
            timesteps = np.array(timesteps, dtype=np.float32)
        else:
            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
            if self.config.timestep_spacing == "linspace":
                timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
            elif self.config.timestep_spacing == "leading":
                step_ratio = num_train_timesteps // self.num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
                timesteps += self.config.steps_offset
            elif self.config.timestep_spacing == "trailing":
                step_ratio = num_train_timesteps / self.num_inference_steps
                # creates integer timesteps by multiplying by ratio
                # casting to int to avoid issues when num_inference_step is power of 3
                timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
                timesteps -= 1
            else:
                raise ValueError(
                    f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                )

        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), interpolated at the chosen timesteps
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        # append the terminal 0.0, then duplicate all interior sigmas (one entry per model evaluation)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        # every timestep after the first is duplicated to line up with the duplicated sigmas
        timesteps = torch.from_numpy(timesteps)
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])

        self.timesteps = timesteps.to(device=device)

        # empty dt and derivative
        self.prev_derivative = None
        self.dt = None

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al.
def step(
    self,
    model_output: Union[torch.Tensor, np.ndarray],
    timestep: Union[float, torch.Tensor],
    sample: Union[torch.Tensor, np.ndarray],
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    The scheduler alternates between two states. In first-order state (`self.dt is None`) it takes an Euler step
    and stores the derivative, step size and input sample. In second-order state it averages the stored derivative
    with the new one and re-applies the stored step size to the stored sample (Heun's method), then returns to
    first-order state.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain. Only used to initialise the step index on the
            first call.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        `ValueError`:
            If `config.prediction_type` is not one of `epsilon`, `sample` or `v_prediction`.
    """
    if self.step_index is None:
        self._init_step_index(timestep)

    # The state cannot change before the bookkeeping at the end, so it is read once.
    first_order = self.state_in_first_order
    if first_order:
        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]
    else:
        # 2nd order / Heun's method: sigmas are duplicated in the schedule, so look one entry back
        sigma = self.sigmas[self.step_index - 1]
        sigma_next = self.sigmas[self.step_index]

    # currently only gamma=0 is supported. This usually works best anyways.
    # We can support gamma in the future but then need to scale the timestep before
    # passing it to the model which requires a change in API
    gamma = 0
    sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    sigma_input = sigma_hat if first_order else sigma_next
    prediction_type = self.config.prediction_type
    if prediction_type == "epsilon":
        pred_original_sample = sample - sigma_input * model_output
    elif prediction_type == "v_prediction":
        pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
            sample / (sigma_input**2 + 1)
        )
    elif prediction_type == "sample":
        pred_original_sample = model_output
    else:
        # `sample` is handled above, so it belongs in the list of accepted values.
        raise ValueError(
            f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
        )

    if self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    if first_order:
        # 2. Convert to an ODE derivative for 1st order
        derivative = (sample - pred_original_sample) / sigma_hat
        # 3. delta timestep
        dt = sigma_next - sigma_hat

        # store for 2nd order step
        self.prev_derivative = derivative
        self.dt = dt
        self.sample = sample
    else:
        # 2. 2nd order / Heun's method: average the two slopes
        derivative = (sample - pred_original_sample) / sigma_next
        derivative = (self.prev_derivative + derivative) / 2

        # 3. take prev timestep & sample
        dt = self.dt
        sample = self.sample

        # free dt and derivative
        # Note, this puts the scheduler in "first order mode"
        self.prev_derivative = None
        self.dt = None
        self.sample = None

    prev_sample = sample + derivative * dt

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
@register_to_config
def __init__(
    self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
):
    """
    Build the default schedule over `num_train_timesteps` steps and initialise the multistep state.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            Number of steps passed to `set_timesteps` to create the initial schedule.
        trained_betas (`np.ndarray` or `List[float]`, *optional*):
            Not used directly here; `set_timesteps` reads it back through `self.config.trained_betas`.
    """
    # `set_timesteps` fills in `betas`, `alphas` and `timesteps`
    self.set_timesteps(num_train_timesteps)

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # For now we only support F-PNDM, i.e. the runge-kutta method
    # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
    # mainly at formula (9), (12), (13) and the Algorithm 2.
    self.pndm_order = 4

    # running values: no step taken yet, no begin index chosen, empty history
    self._step_index = self._begin_index = None
    self.ets = []
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep with a linear multistep update.

    Each call combines `sample` and `model_output` into one history entry, appends it to `self.ets`, and blends
    the most recent entries with Adams-Bashforth coefficients of order 1 to 4 (the order grows with the number of
    entries available). The blended value is handed to `_get_prev_sample`. Only the four newest history entries
    are kept, because no formula below reads further back.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain. Only used to initialise the step index on the
            first call.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        `ValueError`:
            If `set_timesteps` has not been run (`num_inference_steps` is `None`).
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )
    if self.step_index is None:
        self._init_step_index(timestep)

    timestep_index = self.step_index
    prev_timestep_index = self.step_index + 1

    ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index]
    self.ets.append(ets)

    # The highest-order formula reads ets[-4:], so older entries are dead weight. Dropping them keeps the
    # history (one sample-sized tensor per step) from growing with the number of inference steps.
    max_history = 4
    del self.ets[:-max_history]

    if len(self.ets) == 1:
        ets = self.ets[-1]
    elif len(self.ets) == 2:
        ets = (3 * self.ets[-1] - self.ets[-2]) / 2
    elif len(self.ets) == 3:
        ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
    else:
        ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

    prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets)

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize an `alpha_bar(t)` curve, defined for t in [0, 1], into a beta schedule.

    `alpha_bar(t)` is the cumulative product of (1 - beta) up to time t. For each of the
    `num_diffusion_timesteps` equal sub-intervals [t1, t2] the beta is `1 - alpha_bar(t2) / alpha_bar(t1)`,
    capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): upper bound applied to every beta; values lower than 1 prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): which `alpha_bar` curve to use, either
            `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): float32 tensor holding one beta per timestep.

    Raises:
        `ValueError`: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar_fn = cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar_fn = exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    # Sub-interval i runs from i / total to (i + 1) / total.
    capped = [
        min(1 - alpha_bar_fn((index + 1) / total) / alpha_bar_fn(index / total), max_beta)
        for index in range(total)
    ]
    return torch.tensor(capped, dtype=torch.float32)
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,  # sensible defaults
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    use_karras_sigmas: Optional[bool] = False,
    prediction_type: str = "epsilon",
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
):
    """
    Build the training beta schedule, derive `alphas` / `alphas_cumprod`, and create a default inference
    schedule spanning all `num_train_timesteps` steps.

    `trained_betas`, when given, takes precedence over `beta_schedule`. The arguments not referenced in this
    body (`use_karras_sigmas`, `prediction_type`, `timestep_spacing`, `steps_offset`) are read elsewhere in the
    class through `self.config`.

    Raises:
        `NotImplementedError`: if `beta_schedule` is not `linear`, `scaled_linear` or `squaredcos_cap_v2`
            (and no `trained_betas` were passed).
    """
    if trained_betas is not None:
        schedule = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        schedule = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        schedule = sqrt_betas**2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        schedule = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.betas = schedule
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # set all values: one inference step per training step, no device move
    self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
    self._step_index = self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Union[str, torch.device] = None,
    num_train_timesteps: Optional[int] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Besides `self.timesteps`, this fills `self.sigmas`, `self.sigmas_interpol`, `self.sigmas_up`,
    `self.sigmas_down` and `self.log_sigmas`, and resets `self.sample`, `self._step_index` and
    `self._begin_index` to `None`.

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            `self.sigmas` is moved to the CPU at the end regardless.
        num_train_timesteps (`int`, *optional*):
            The number of diffusion steps used when training the model. If `None` (or otherwise falsy),
            `self.config.num_train_timesteps` is used.

    Raises:
        `ValueError`:
            If `config.timestep_spacing` is not one of `"linspace"`, `"leading"` or `"trailing"`.
    """
    self.num_inference_steps = num_inference_steps

    num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    if self.config.timestep_spacing == "linspace":
        timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
    elif self.config.timestep_spacing == "leading":
        step_ratio = num_train_timesteps // self.num_inference_steps
        # whole-number timesteps (multiples of the integer ratio), reversed to descending order and
        # stored as float32
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
        timesteps += self.config.steps_offset
    elif self.config.timestep_spacing == "trailing":
        step_ratio = num_train_timesteps / self.num_inference_steps
        # fractional ratio: count down from `num_train_timesteps`, round to whole timesteps, then
        # subtract one
        timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    # training sigmas, then linearly interpolated at the inference timesteps
    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    log_sigmas = np.log(sigmas)
    sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

    if self.config.use_karras_sigmas:
        # replace the sigmas by the Karras schedule and map each one back to a (rounded) timestep
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()

    self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
    # append the terminal sigma 0.0
    sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
    sigmas = torch.from_numpy(sigmas).to(device=device)

    # compute up and down sigmas
    # `roll(-1)` shifts every sigma one slot forward and wraps the first one to the end; that wrapped
    # entry is overwritten with 0.0 so `sigmas_next[i]` is the sigma following `sigmas[i]`.
    sigmas_next = sigmas.roll(-1)
    sigmas_next[-1] = 0.0
    # NOTE(review): the last entry of `sigmas` is the appended 0.0, so the division below is 0/0 for the
    # final element of `sigmas_up`; only `sigmas_down[-1]` is reset afterwards — confirm the final
    # `sigmas_up` entry is never read.
    sigmas_up = (sigmas_next**2 * (sigmas**2 - sigmas_next**2) / sigmas**2) ** 0.5
    sigmas_down = (sigmas_next**2 - sigmas_up**2) ** 0.5
    sigmas_down[-1] = 0.0

    # compute interpolated sigmas
    # midpoint between `sigmas` and `sigmas_down` in log space
    sigmas_interpol = sigmas.log().lerp(sigmas_down.log(), 0.5).exp()
    # presumably overwrites values that came out of log(0) above — TODO confirm
    sigmas_interpol[-2:] = 0.0

    # set sigmas
    # every array is laid out as [x0, x1, x1, x2, x2, ..., x_last, x_last, x_last]
    self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
    self.sigmas_interpol = torch.cat(
        [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
    )
    self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
    self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])

    if str(device).startswith("mps"):
        # force float32 when the target device string starts with "mps"
        timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
    else:
        timesteps = torch.from_numpy(timesteps).to(device)

    # `_sigma_to_t` is fed CPU copies of the interpolated sigmas and of the log-sigma table
    sigmas_interpol = sigmas_interpol.cpu()
    log_sigmas = self.log_sigmas.cpu()
    timesteps_interpol = np.array(
        [self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
    )

    timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
    # pair each interpolated timestep with the next regular timestep and flatten the pairs:
    # [ti_0, t_1, ti_1, t_2, ...]
    interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()

    # final schedule: the first regular timestep followed by the interleaved pairs
    self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])

    # `state_in_first_order` checks `self.sample is None`, so this starts in first-order state
    self.sample = None

    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
def step(
    self,
    model_output: Union[torch.Tensor, np.ndarray],
    timestep: Union[float, torch.Tensor],
    sample: Union[torch.Tensor, np.ndarray],
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    Calls alternate between two states. In first-order state (`self.sample is None`) the input sample is stored
    and moved to the interpolated sigma. In second-order state the stored sample is moved to `sigma_down` using
    the derivative evaluated at the interpolated sigma, and noise scaled by `sigma_up` is added.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain. Only used to initialise the step index on the
            first call.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        `NotImplementedError`: if `config.prediction_type` is `sample`.
        `ValueError`: if `config.prediction_type` is anything other than `epsilon`, `v_prediction` or `sample`.
    """
    if self.step_index is None:
        self._init_step_index(timestep)

    first_order = self.state_in_first_order
    previous_index = self.step_index - 1
    # First-order calls read the current schedule entry, second-order (KDPM2) calls the previous one.
    # `sigmas_down` is read at the previous entry in both states.
    lookup_index = self.step_index if first_order else previous_index

    sigma = self.sigmas[lookup_index]
    sigma_interpol = self.sigmas_interpol[lookup_index]
    sigma_up = self.sigmas_up[lookup_index]
    sigma_down = self.sigmas_down[previous_index]

    # currently only gamma=0 is supported. This usually works best anyways.
    # We can support gamma in the future but then need to scale the timestep before
    # passing it to the model which requires a change in API
    gamma = 0
    sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

    # The noise is drawn on every call, before the prediction type is checked, even though only the
    # second-order branch adds it to the result.
    noise = randn_tensor(
        model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
    )

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    prediction_type = self.config.prediction_type
    if prediction_type == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    if prediction_type not in ("epsilon", "v_prediction"):
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    sigma_input = sigma_hat if first_order else sigma_interpol
    if prediction_type == "epsilon":
        pred_original_sample = sample - sigma_input * model_output
    else:
        pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
            sample / (sigma_input**2 + 1)
        )

    residual = sample - pred_original_sample
    if first_order:
        # 2. ODE derivative for the 1st order half-step, 3. step size towards the interpolated sigma
        derivative = residual / sigma_hat
        dt = sigma_interpol - sigma_hat

        # store for 2nd order step
        self.sample = sample
        self.dt = dt
        prev_sample = sample + derivative * dt
    else:
        # DPM-Solver-2
        # 2. ODE derivative at the interpolated sigma, 3. step size towards `sigma_down`
        derivative = residual / sigma_interpol
        dt = sigma_down - sigma_hat

        # restart from the sample stored by the first-order call; clearing it returns to first-order state
        stored_sample = self.sample
        self.sample = None

        prev_sample = stored_sample + derivative * dt
        prev_sample = prev_sample + noise * sigma_up

    # upon completion increase step index by one
    self._step_index += 1

    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous `alpha_bar(t)` curve, defined on t in [0, 1], into a beta schedule.

    `alpha_bar(t)` is the cumulative product of (1 - beta) up to time t. Each beta is
    `1 - alpha_bar(t_next) / alpha_bar(t_current)` for consecutive grid points, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): upper cap applied to every beta; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor with `num_diffusion_timesteps` entries.

    Raises:
        ValueError: If `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar = cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar = exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    schedule = [
        min(1 - alpha_bar((k + 1) / total) / alpha_bar(k / total), max_beta)
        for k in range(total)
    ]
    return torch.tensor(schedule, dtype=torch.float32)
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,  # sensible defaults
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    use_karras_sigmas: Optional[bool] = False,
    prediction_type: str = "epsilon",
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
):
    """
    Build the beta/alpha tables and initialise the scheduler with a full-length schedule.

    Explicit `trained_betas` take precedence over `beta_schedule`. Supported schedules are
    `linear`, `scaled_linear` and `squaredcos_cap_v2`.

    Raises:
        NotImplementedError: If `trained_betas` is `None` and `beta_schedule` is not supported.
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model:
        # interpolate linearly between the square roots, then square
        root_start = beta_start**0.5
        root_end = beta_end**0.5
        betas = torch.linspace(root_start, root_end, num_train_timesteps, dtype=torch.float32) ** 2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.betas = betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # set all values, using every training timestep as an inference step
    self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Union[str, torch.device] = None,
    num_train_timesteps: Optional[int] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Besides `self.timesteps`, this rebuilds `self.sigmas`, `self.sigmas_interpol` and `self.log_sigmas`, and
    resets `self.sample`, `self._step_index` and `self._begin_index` to `None`.

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        num_train_timesteps (`int`, *optional*):
            Number of training timesteps to spread the schedule over. A falsy value (`None` or `0`) falls back
            to `self.config.num_train_timesteps`.

    Raises:
        ValueError: If `self.config.timestep_spacing` is not one of `linspace`, `leading` or `trailing`.
    """
    self.num_inference_steps = num_inference_steps

    num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    if self.config.timestep_spacing == "linspace":
        timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
    elif self.config.timestep_spacing == "leading":
        step_ratio = num_train_timesteps // self.num_inference_steps
        # creates whole-number timesteps by multiplying by the integer ratio;
        # they are rounded and then stored as float32
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
        timesteps += self.config.steps_offset
    elif self.config.timestep_spacing == "trailing":
        step_ratio = num_train_timesteps / self.num_inference_steps
        # counts down from num_train_timesteps in (possibly fractional) strides;
        # values are rounded and then stored as float32
        timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    # sigma for every training timestep, then sampled at the selected timesteps
    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    log_sigmas = np.log(sigmas)
    sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

    if self.config.use_karras_sigmas:
        # replace the sigmas by the Karras schedule and recompute matching (rounded) timesteps
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()

    self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
    # append a terminal sigma of 0
    sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
    sigmas = torch.from_numpy(sigmas).to(device=device)

    # interpolate sigmas: midpoint in log space between each sigma and its predecessor
    # (`roll(1)` shifts the sequence by one position)
    sigmas_interpol = sigmas.log().lerp(sigmas.roll(1).log(), 0.5).exp()

    # every entry after the first is duplicated and the last one is appended once more
    self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
    self.sigmas_interpol = torch.cat(
        [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
    )

    timesteps = torch.from_numpy(timesteps).to(device)

    # interpolate timesteps: map the interpolated sigmas back to (fractional) timesteps
    sigmas_interpol = sigmas_interpol.cpu()
    log_sigmas = self.log_sigmas.cpu()
    timesteps_interpol = np.array(
        [self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
    )
    timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
    # alternate interpolated and regular timesteps: (interpol_1, t_1, interpol_2, t_2, ...)
    interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()

    self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])

    # no stored sample: the next `step` call starts in first-order state
    self.sample = None

    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
def _sigma_to_t(self, sigma, log_sigmas):
    """
    Convert noise level(s) `sigma` into fractional positions within the `log_sigmas` table.

    The position is found by linear interpolation in log-sigma space between the two neighbouring
    table entries; values outside the table are clamped to its first / last position.

    Args:
        sigma: Noise level(s) to convert; must expose `.shape`.
        log_sigmas: One-dimensional table of log-sigmas indexed by timestep.

    Returns:
        Interpolated position(s), reshaped to `sigma.shape`.
    """
    # clamp before the logarithm so that a zero sigma does not produce -inf
    target = np.log(np.maximum(sigma, 1e-10))

    # row i holds the signed distance between the target and table entry i
    gaps = target - log_sigmas[:, np.newaxis]

    # lower neighbour: capped so that the upper neighbour stays inside the table
    highest_lower = log_sigmas.shape[0] - 2
    lower = np.cumsum((gaps >= 0), axis=0).argmax(axis=0).clip(max=highest_lower)
    upper = lower + 1

    lower_value = log_sigmas[lower]
    upper_value = log_sigmas[upper]

    # interpolation weight of the upper neighbour, clamped to [0, 1]
    weight = np.clip((lower_value - target) / (lower_value - upper_value), 0, 1)

    # blend the two neighbouring indices
    position = (1 - weight) * lower + weight * upper
    return position.reshape(sigma.shape)
def step(
    self,
    model_output: Union[torch.Tensor, np.ndarray],
    timestep: Union[float, torch.Tensor],
    sample: Union[torch.Tensor, np.ndarray],
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance the sample by one scheduler call, reversing the diffusion process with the model output
    (most often the predicted noise).

    Calls alternate between two states. In first-order state the call steps to the interpolated
    sigma and stores the incoming `sample` on the instance. The following call uses the stored sample
    together with the new model output to step to the next sigma, then clears the stored sample.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain. Only used to initialise the step
            index when it has not been set yet.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        NotImplementedError: If `self.config.prediction_type` is `"sample"`.
        ValueError: If `self.config.prediction_type` is any other unsupported value.
    """
    if self.step_index is None:
        self._init_step_index(timestep)

    # the state only changes at the very end of this call, so read it once
    first_order = self.state_in_first_order
    idx = self.step_index

    if first_order:
        sigma = self.sigmas[idx]
        sigma_interpol = self.sigmas_interpol[idx + 1]
        sigma_next = self.sigmas[idx + 1]
    else:
        # 2nd order / KDPM2's method
        sigma = self.sigmas[idx - 1]
        sigma_interpol = self.sigmas_interpol[idx]
        sigma_next = self.sigmas[idx]

    # currently only gamma=0 is supported. This usually works best anyways.
    # We can support gamma in the future but then need to scale the timestep before
    # passing it to the model which requires a change in API
    gamma = 0
    sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    prediction_type = self.config.prediction_type
    if prediction_type == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    if prediction_type not in ("epsilon", "v_prediction"):
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    # the model was evaluated at sigma_hat (first order) or at the interpolated sigma (second order)
    sigma_input = sigma_hat if first_order else sigma_interpol
    if prediction_type == "epsilon":
        pred_original_sample = sample - sigma_input * model_output
    else:
        pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
            sample / (sigma_input**2 + 1)
        )

    if first_order:
        # 2. ODE derivative and 3. step size towards the interpolated sigma
        derivative = (sample - pred_original_sample) / sigma_hat
        dt = sigma_interpol - sigma_hat

        # keep the incoming sample for the second-order call
        self.sample = sample
        start = sample
    else:
        # DPM-Solver-2
        # 2. ODE derivative at the interpolated sigma and 3. step size towards the next sigma
        derivative = (sample - pred_original_sample) / sigma_interpol
        dt = sigma_next - sigma_hat

        # restart from the stored sample and return to first-order state
        start = self.sample
        self.sample = None

    # upon completion increase step index by one
    self._step_index += 1

    prev_sample = start + derivative * dt

    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)
NVIDIA and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Optional, Tuple, Union import flax import jax import jax.numpy as jnp from jax import random from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput from .scheduling_utils_flax import FlaxSchedulerMixin @flax.struct.dataclass class KarrasVeSchedulerState: # setable values num_inference_steps: Optional[int] = None timesteps: Optional[jnp.ndarray] = None schedule: Optional[jnp.ndarray] = None # sigma(t_i) @classmethod def create(cls): return cls() @dataclass class FlaxKarrasVeOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): Derivative of predicted original image sample (x_0). state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. """ prev_sample: jnp.ndarray derivative: jnp.ndarray state: KarrasVeSchedulerState class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. 
Use Algorithm 2 and the VE column of Table 1 from [1] for reference. [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. Args: sigma_min (`float`): minimum noise magnitude sigma_max (`float`): maximum noise magnitude s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000, 1.011]. s_churn (`float`): the parameter controlling the overall amount of stochasticity. A reasonable range is [0, 100]. s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). A reasonable range is [0, 10]. s_max (`float`): the end value of the sigma range where we add noise. A reasonable range is [0.2, 80]. 
def set_timesteps(
    self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> KarrasVeSchedulerState:
    """
    Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

    The schedule entry for timestep `i` is
    `sigma_max**2 * (sigma_min**2 / sigma_max**2) ** (i / (num_inference_steps - 1))`, with `i` running from
    `num_inference_steps - 1` down to `0`. For a single inference step the exponent is taken as `0`, so the
    schedule is `[sigma_max**2]`.

    Args:
        state (`KarrasVeSchedulerState`):
            the `FlaxKarrasVeScheduler` state data class.
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        shape (`Tuple`, *optional*):
            not used by this scheduler.

    Returns:
        `KarrasVeSchedulerState`: copy of `state` with `num_inference_steps`, `schedule` (float32) and
        `timesteps` replaced.
    """
    timesteps = jnp.arange(0, num_inference_steps)[::-1].copy()

    # With one inference step `num_inference_steps - 1` is zero and the exponent would be 0 / 0;
    # clamping the denominator to 1 leaves every other step count unchanged.
    denominator = max(num_inference_steps - 1, 1)
    schedule = [
        (
            self.config.sigma_max**2
            * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / denominator)
        )
        for i in timesteps
    ]

    return state.replace(
        num_inference_steps=num_inference_steps,
        schedule=jnp.array(schedule, dtype=jnp.float32),
        timesteps=timesteps,
    )
def step(
    self,
    state: KarrasVeSchedulerState,
    model_output: jnp.ndarray,
    sigma_hat: float,
    sigma_prev: float,
    sample_hat: jnp.ndarray,
    return_dict: bool = True,
) -> Union[FlaxKarrasVeOutput, Tuple]:
    """
    Take one Euler step from noise level `sigma_hat` to `sigma_prev` using the model output.

    Args:
        state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class; returned unchanged.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        sigma_hat (`float`): noise level that goes with `sample_hat` (presumably the value returned by
            `add_noise_to_input`).
        sigma_prev (`float`): noise level to step to.
        sample_hat (`jnp.ndarray`): sample at noise level `sigma_hat`.
        return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

    Returns:
        [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
        chain and derivative. [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] if `return_dict` is
        True, otherwise a `tuple` `(sample, derivative, state)`. When returning a tuple, the first element is the
        sample tensor.
    """
    # predicted clean sample and the resulting derivative at sigma_hat
    denoised = sample_hat + sigma_hat * model_output
    slope = (sample_hat - denoised) / sigma_hat

    # Euler update across the change in noise level
    next_sample = sample_hat + (sigma_prev - sigma_hat) * slope

    if return_dict:
        return FlaxKarrasVeOutput(prev_sample=next_sample, derivative=slope, state=state)
    return (next_sample, slope, state)
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion # and https://github.com/hojonathanho/diffusion import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging from ..utils.torch_utils import randn_tensor from .scheduling_utils import SchedulerMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class LCMSchedulerOutput(BaseOutput): """ Output class for the scheduler's `step` function output. Args: prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample `(x_{0})` based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. 
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """
    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)

    The square roots of the cumulative alphas are shifted so that the last entry becomes zero and then
    scaled so that the first entry keeps its original value; the result is converted back to betas.

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # sqrt of the cumulative product of alphas
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    first = sqrt_alpha_bar[0]
    last = sqrt_alpha_bar[-1]

    # shift so the last timestep is zero, then scale so the first timestep keeps its old value
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    # undo the square root, then undo the cumulative product
    alpha_bar = sqrt_alpha_bar**2
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])

    return 1 - step_alphas
original_inference_steps (`int`, *optional*, defaults to 50): The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule. clip_sample (`bool`, defaults to `True`): Clip the predicted sample for numerical stability. clip_sample_range (`float`, defaults to 1.0): The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. set_alpha_to_one (`bool`, defaults to `True`): Each diffusion step uses the alphas product value at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the alpha value at step 0. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. dynamic_thresholding_ratio (`float`, defaults to 0.995): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 
timestep_scaling (`float`, defaults to 10.0): The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation error at the default of `10.0` is already pretty small). rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.00085, beta_end: float = 0.012, beta_schedule: str = "scaled_linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, original_inference_steps: int = 50, clip_sample: bool = False, clip_sample_range: float = 1.0, set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, timestep_spacing: str = "leading", timestep_scaling: float = 10.0, rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Return the position of `timestep` inside a timestep schedule.

    Args:
        timestep: The timestep value to look up.
        schedule_timesteps (*optional*): The schedule to search. Falls back to `self.timesteps` when `None`.

    Returns:
        `int`: The index of the match. When the value occurs more than once, the second occurrence is returned.
    """
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (schedule == timestep).nonzero()

    # For the very first `step`, prefer the second match (or the only one if there is just a single match).
    # This avoids accidentally skipping an entry when denoising starts in the middle of the schedule
    # (e.g. for image-to-image).
    if len(matches) > 1:
        return matches[1].item()
    return matches[0].item()
def set_timesteps(
    self,
    num_inference_steps: Optional[int] = None,
    device: Union[str, torch.device] = None,
    original_inference_steps: Optional[int] = None,
    timesteps: Optional[List[int]] = None,
    strength: float = 1.0,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`, *optional*):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        original_inference_steps (`int`, *optional*):
            The original number of inference steps, which will be used to generate a linearly-spaced timestep
            schedule (which is different from the standard `diffusers` implementation). We will then take
            `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
            our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
            schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
        strength (`float`, defaults to 1.0):
            Fraction of the schedule to keep (e.g. for img2img pipelines). It shortens the original
            training/distillation schedule to its first `int(original_steps * strength)` entries and, for custom
            `timesteps`, keeps only the last `int(len(timesteps) * strength)` entries.

    Raises:
        `ValueError`: If both or neither of `num_inference_steps` and `timesteps` are given, if a step count is not
            positive, if `timesteps` is empty, not strictly descending, or starts at or beyond
            `num_train_timesteps`, or if the requested step counts are incompatible with each other.
    """
    # 0. Check inputs
    if num_inference_steps is None and timesteps is None:
        raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")

    if num_inference_steps is not None and timesteps is not None:
        raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")

    # Reject non-positive step counts up front; `0` would otherwise surface as a ZeroDivisionError below.
    if num_inference_steps is not None and num_inference_steps < 1:
        raise ValueError(f"`num_inference_steps` must be a positive integer, but is {num_inference_steps}.")

    # An empty custom schedule would otherwise surface as an IndexError on `timesteps[0]`.
    if timesteps is not None and len(timesteps) == 0:
        raise ValueError("`timesteps` must contain at least one timestep.")

    # 1. Calculate the LCM original training/distillation timestep schedule.
    original_steps = (
        original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
    )

    if original_steps < 1:
        raise ValueError(f"`original_inference_steps` must be a positive integer, but is {original_steps}.")

    if original_steps > self.config.num_train_timesteps:
        raise ValueError(
            f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
            f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
            f" maximal {self.config.num_train_timesteps} timesteps."
        )

    # LCM Timesteps Setting
    # The skipping step parameter k from the paper.
    k = self.config.num_train_timesteps // original_steps
    # LCM Training/Distillation Steps Schedule
    # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
    lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1

    # 2. Calculate the LCM inference timestep schedule.
    if timesteps is not None:
        # 2.1 Handle custom timestep schedules.
        train_timesteps = set(lcm_origin_timesteps)
        non_train_timesteps = []
        for i in range(1, len(timesteps)):
            if timesteps[i] >= timesteps[i - 1]:
                raise ValueError("`custom_timesteps` must be in descending order.")

            if timesteps[i] not in train_timesteps:
                non_train_timesteps.append(timesteps[i])

        if timesteps[0] >= self.config.num_train_timesteps:
            raise ValueError(
                f"`timesteps` must start before `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps}."
            )

        # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
        if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1:
            logger.warning(
                f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
                f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get"
                f" unexpected results when using this timestep schedule."
            )

        # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
        if non_train_timesteps:
            logger.warning(
                f"The custom timestep schedule contains the following timesteps which are not on the original"
                f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
                f" when using this timestep schedule."
            )

        # Raise warning if custom timestep schedule is longer than original_steps
        if len(timesteps) > original_steps:
            logger.warning(
                f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
                f" the length of the timestep schedule used for training: {original_steps}. You may get some"
                f" unexpected results when using this timestep schedule."
            )

        timesteps = np.array(timesteps, dtype=np.int64)
        self.custom_timesteps = True

        # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
        num_custom_steps = len(timesteps)
        init_timestep = min(int(num_custom_steps * strength), num_custom_steps)
        t_start = max(num_custom_steps - init_timestep, 0)
        timesteps = timesteps[t_start * self.order :]

        # Keep `num_inference_steps` in sync with the schedule that will actually be run. `step` detects the
        # final step by comparing the step index with `num_inference_steps - 1`; with the untruncated length it
        # would never see a final step and would inject noise into the last sample.
        self.num_inference_steps = len(timesteps)
    else:
        # 2.2 Create the "standard" LCM inference timestep schedule.
        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        skipping_step = len(lcm_origin_timesteps) // num_inference_steps

        if skipping_step < 1:
            raise ValueError(
                f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
            )

        self.num_inference_steps = num_inference_steps

        if num_inference_steps > original_steps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
                f" {original_steps} because the final timestep schedule will be a subset of the"
                f" `original_inference_steps`-sized initial timestep schedule."
            )

        # LCM Inference Steps Schedule
        lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
        # Select (approximately) evenly spaced indices from lcm_origin_timesteps.
        inference_indices = np.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
        inference_indices = np.floor(inference_indices).astype(np.int64)
        timesteps = lcm_origin_timesteps[inference_indices]

    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)

    # A new schedule invalidates any step bookkeeping from a previous run.
    self._step_index = None
    self._begin_index = None
get previous step value prev_step_index = self.step_index + 1 if prev_step_index < len(self.timesteps): prev_timestep = self.timesteps[prev_step_index] else: prev_timestep = timestep # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev # 3. Get scalings for boundary conditions c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) # 4. Compute the predicted original sample x_0 based on the model parameterization if self.config.prediction_type == "epsilon": # noise-prediction predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() elif self.config.prediction_type == "sample": # x-prediction predicted_original_sample = model_output elif self.config.prediction_type == "v_prediction": # v-prediction predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" " `v_prediction` for `LCMScheduler`." ) # 5. Clip or threshold "predicted x_0" if self.config.thresholding: predicted_original_sample = self._threshold_sample(predicted_original_sample) elif self.config.clip_sample: predicted_original_sample = predicted_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range ) # 6. Denoise model output using boundary conditions denoised = c_out * predicted_original_sample + c_skip * sample # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference # Noise is not used on the final timestep of the timestep schedule. # This also means that noise is not used for one-step sampling. 
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Mix `noise` into `original_samples` according to the cumulative alpha product at each timestep.

    Args:
        original_samples (`torch.Tensor`): The samples to be noised.
        noise (`torch.Tensor`): The noise to blend in.
        timesteps (`torch.IntTensor`): Indices into `self.alphas_cumprod` selecting the noise level.

    Returns:
        `torch.Tensor`: `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise`.
    """
    target_device = original_samples.device

    # Keep the schedule on the sample's device so later calls do not repeat the CPU-to-GPU transfer;
    # only the local copy is cast to the sample's dtype.
    self.alphas_cumprod = self.alphas_cumprod.to(device=target_device)
    schedule = self.alphas_cumprod.to(dtype=original_samples.dtype)
    timesteps = timesteps.to(target_device)

    # One coefficient per selected timestep, with trailing singleton axes so it broadcasts over the samples.
    broadcast_shape = [-1] + [1] * (original_samples.ndim - 1)
    signal_scale = (schedule[timesteps] ** 0.5).reshape(broadcast_shape)
    noise_scale = ((1 - schedule[timesteps]) ** 0.5).reshape(broadcast_shape)

    return signal_scale * original_samples + noise_scale * noise
def previous_timestep(self, timestep):
    """
    Return the timestep that follows `timestep` in the denoising order.

    With a custom schedule (`self.custom_timesteps` is truthy) the next entry of `self.timesteps` is returned, or
    `torch.tensor(-1)` when `timestep` is the last entry. Otherwise the result is `timestep` minus the integer
    ratio of `num_train_timesteps` to the number of inference steps (falling back to `num_train_timesteps` when no
    inference step count has been set).
    """
    if not self.custom_timesteps:
        step_count = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
        return timestep - self.config.num_train_timesteps // step_count

    # Position of the first occurrence of `timestep` in the custom schedule.
    position = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
    if position == self.timesteps.shape[0] - 1:
        return torch.tensor(-1)
    return self.timesteps[position + 1]
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous `alpha_bar(t)` curve, defined on t in [0, 1], into a beta schedule.

    `alpha_bar(t)` is the cumulative product of `(1 - beta)` up to time `t`. Each beta is derived from the ratio of
    `alpha_bar` at two consecutive grid points and capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`):
            The number of betas to produce.
        max_beta (`float`):
            The maximum beta to use; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`):
            The type of noise schedule for alpha_bar. Choose from `cosine` or `exp`.

    Returns:
        `torch.Tensor`: A float32 tensor holding `num_diffusion_timesteps` betas.

    Raises:
        `ValueError`: If `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar_fn = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar_fn = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped from above by `max_beta`.
    betas = [
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta) for step in range(total)
    ]
    return torch.tensor(betas, dtype=torch.float32)
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    use_karras_sigmas: Optional[bool] = False,
    prediction_type: str = "epsilon",
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
):
    # Resolve the training beta schedule first; explicitly provided betas take precedence over `beta_schedule`.
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        betas = sqrt_betas**2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.betas = betas
    self.alphas = 1.0 - betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Noise levels sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), stored from noisiest to cleanest with a
    # trailing 0.0 appended.
    noise_levels = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    noise_levels = np.concatenate([noise_levels[::-1], [0.0]]).astype(np.float32)
    self.sigmas = torch.from_numpy(noise_levels)

    # setable values
    self.num_inference_steps = None
    self.use_karras_sigmas = use_karras_sigmas
    # Start from a schedule that spans every training timestep.
    self.set_timesteps(num_train_timesteps, None)
    self.derivatives = []
    self.is_scale_input_called = False

    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
""" self._begin_index = begin_index def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: sample (`torch.Tensor`): The input sample. timestep (`float` or `torch.Tensor`): The current timestep in the diffusion chain. Returns: `torch.Tensor`: A scaled input sample. """ if self.step_index is None: self._init_step_index(timestep) sigma = self.sigmas[self.step_index] sample = sample / ((sigma**2 + 1) ** 0.5) self.is_scale_input_called = True return sample def get_lms_coefficient(self, order, t, current_order): """ Compute the linear multistep coefficient. Args: order (): t (): current_order (): """ def lms_derivative(tau): prod = 1.0 for k in range(order): if current_order == k: continue prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k]) return prod integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0] return integrated_coeff def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ self.num_inference_steps = num_inference_steps # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[ ::-1 ].copy() elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Return the position of `timestep` inside a timestep schedule.

    Args:
        timestep: The timestep value to look up.
        schedule_timesteps (*optional*): The schedule to search. Falls back to `self.timesteps` when `None`.

    Returns:
        `int`: The index of the match. When the value occurs more than once, the second occurrence is returned.
    """
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (schedule == timestep).nonzero()

    # The sigma index taken for the very first `step` is the second match (or the only one if there is just a
    # single match). This avoids accidentally skipping a sigma when denoising starts in the middle of the
    # schedule (e.g. for image-to-image).
    if len(matches) > 1:
        return matches[1].item()
    return matches[0].item()
This function propagates the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. timestep (`float` or `torch.Tensor`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. order (`int`, defaults to 4): The order of the linear multistep method. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ if not self.is_scale_input_called: warnings.warn( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." ) if self.step_index is None: self._init_step_index(timestep) sigma = self.sigmas[self.step_index] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": pred_original_sample = sample - sigma * model_output elif self.config.prediction_type == "v_prediction": # * c_out + input * c_skip pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) elif self.config.prediction_type == "sample": pred_original_sample = model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" ) # 2. Convert to an ODE derivative derivative = (sample - pred_original_sample) / sigma self.derivatives.append(derivative) if len(self.derivatives) > order: self.derivatives.pop(0) # 3. 
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """
    Return `original_samples + noise * sigma`, with one sigma selected per batch entry.

    Args:
        original_samples (`torch.Tensor`):
            Samples to be noised; their device and dtype are used for the sigma table.
        noise (`torch.Tensor`):
            Noise that is scaled by sigma and added to `original_samples`.
        timesteps (`torch.Tensor`):
            Timesteps, one per batch entry. They are looked up in the schedule only when `begin_index` is `None`.

    Returns:
        `torch.Tensor`: The noised samples.
    """
    target = original_samples.device
    # Bring the sigma table to the device and dtype of the samples.
    sigmas = self.sigmas.to(device=target, dtype=original_samples.dtype)

    if target.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(target, dtype=torch.float32)
        timesteps = timesteps.to(target, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(target)
        timesteps = timesteps.to(target)

    if self.begin_index is None:
        # Training, or a pipeline that never called set_begin_index: look every timestep up in the schedule.
        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    else:
        if self.step_index is not None:
            # Called after the first denoising step (inpainting): reuse the current step index.
            shared_index = self.step_index
        else:
            # Called before the first denoising step (img2img): start from the begin index.
            shared_index = self.begin_index
        step_indices = [shared_index] * timesteps.shape[0]

    sigma = sigmas[step_indices].flatten()
    # Append singleton axes until sigma has as many dimensions as the samples.
    for _ in range(len(original_samples.shape) - len(sigma.shape)):
        sigma = sigma.unsqueeze(-1)

    return original_samples + noise * sigma
Based on the original k-diffusion implementation by Katherine Crowson: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear` or `scaled_linear`. trained_betas (`jnp.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): the `dtype` used for params and computation. 
def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
    """
    Divide the model input by `(sigma**2 + 1) ** 0.5`, where sigma belongs to the given timestep.

    Args:
        state (`LMSDiscreteSchedulerState`):
            the `FlaxLMSDiscreteScheduler` state data class instance.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.
        timestep (`int`):
            current discrete timestep in the diffusion chain.

    Returns:
        `jnp.ndarray`: scaled input sample
    """
    # Position of `timestep` inside the schedule; `size=1` keeps the result shape static.
    (positions,) = jnp.where(state.timesteps == timestep, size=1)
    sigma = state.sigmas[positions[0]]

    scale = (sigma**2 + 1) ** 0.5
    return sample / scale
def step(
    self,
    state: LMSDiscreteSchedulerState,
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
    order: int = 4,
    return_dict: bool = True,
) -> Union[FlaxLMSSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        state (`LMSDiscreteSchedulerState`):
            the `FlaxLMSDiscreteScheduler` state data class instance.
        model_output (`jnp.ndarray`):
            direct output from learned diffusion model.
        timestep (`int`):
            current discrete timestep in the diffusion chain. It is looked up in `state.timesteps` to find the
            position of the current step in the sigma schedule.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.
        order (`int`, defaults to 4):
            maximum number of past derivatives combined by the linear multistep update.
        return_dict (`bool`):
            option for returning tuple rather than FlaxLMSSchedulerOutput class

    Returns:
        [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
        `tuple` of `(prev_sample, state)`.

    Raises:
        `ValueError`: if `set_timesteps` has not been called, or if `prediction_type` is neither `epsilon` nor
        `v_prediction`.
    """
    if state.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # `state.sigmas` has one entry per inference step (plus a trailing 0.0), so it has to be indexed by the
    # position of `timestep` in the schedule and not by the timestep value itself. This is the same lookup
    # that `scale_model_input` performs. The index is made a Python int because it drives `range()` below.
    (step_index,) = jnp.where(state.timesteps == timestep, size=1)
    step_index = int(step_index[0])
    sigma = state.sigmas[step_index]

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    if self.config.prediction_type == "epsilon":
        pred_original_sample = sample - sigma * model_output
    elif self.config.prediction_type == "v_prediction":
        # * c_out + input * c_skip
        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    # 2. Convert to an ODE derivative
    derivative = (sample - pred_original_sample) / sigma

    # The history is stacked along a new leading axis. Appending or deleting without an explicit axis would
    # flatten everything into scalars and destroy the per-step derivatives.
    history = state.derivatives
    if history is None or history.size == 0:
        history = derivative[None]
    else:
        history = jnp.concatenate([history, derivative[None]], axis=0)
    if history.shape[0] > order:
        # Drop the oldest derivative so that at most `order` are kept.
        history = history[1:]
    state = state.replace(derivatives=history)

    # 3. Compute linear multistep coefficients (no more terms than steps taken so far)
    order = min(step_index + 1, order)
    lms_coeffs = [self.get_lms_coefficient(state, order, step_index, curr_order) for curr_order in range(order)]

    # 4. Compute previous sample based on the derivatives path (newest derivative first)
    prev_sample = sample + sum(coeff * past for coeff, past in zip(lms_coeffs, history[::-1]))

    if not return_dict:
        return (prev_sample, state)

    return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize an `alpha_bar(t)` function, defined for t in [0, 1], into a beta schedule.

    For step `i` the beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`, where `N` is
    `num_diffusion_timesteps`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): upper bound applied to every beta; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): which `alpha_bar` function to discretize.
            Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor holding `num_diffusion_timesteps` betas.

    Raises:
        `ValueError`: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type not in ("cosine", "exp"):
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    alpha_bar_fn = _cosine_alpha_bar if alpha_transform_type == "cosine" else _exp_alpha_bar

    total = num_diffusion_timesteps
    betas = [min(1 - alpha_bar_fn((i + 1) / total) / alpha_bar_fn(i / total), max_beta) for i in range(total)]
    return torch.tensor(betas, dtype=torch.float32)
Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. beta_start (`float`, defaults to 0.0001): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. beta_schedule (`str`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. skip_prk_steps (`bool`, defaults to `False`): Allows the scheduler to skip the Runge-Kutta steps defined in the original paper as being required before PLMS steps. set_alpha_to_one (`bool`, defaults to `False`): Each diffusion step uses the alphas product value at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the alpha value at step 0. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). timestep_spacing (`str`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, skip_prk_steps: bool = False, set_alpha_to_one: bool = False, prediction_type: str = "epsilon", timestep_spacing: str = "leading", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # For now we only support F-PNDM, i.e. the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf # mainly at formula (9), (12), (13) and the Algorithm 2. 
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference) and resets the running
    values (`ets`, `counter`, `cur_model_output`).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

    Raises:
        `ValueError`: if `config.timestep_spacing` is not one of `linspace`, `leading` or `trailing`.
    """
    cfg = self.config
    self.num_inference_steps = num_inference_steps
    train_steps = cfg.num_train_timesteps
    spacing = cfg.timestep_spacing

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    if spacing == "linspace":
        evenly_spaced = np.linspace(0, train_steps - 1, num_inference_steps)
        self._timesteps = evenly_spaced.round().astype(np.int64)
    elif spacing == "leading":
        # Integer stride, so the products below are already integer timesteps.
        stride = train_steps // num_inference_steps
        self._timesteps = (np.arange(0, num_inference_steps) * stride).round()
        self._timesteps += cfg.steps_offset
    elif spacing == "trailing":
        # Fractional stride counted down from the last training step, rounded, then put in ascending order.
        stride = train_steps / num_inference_steps
        descending = np.round(np.arange(train_steps, 0, -stride))
        self._timesteps = descending[::-1].astype(np.int64)
        self._timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    if cfg.skip_prk_steps:
        # for some models like stable diffusion the prk steps can/should be skipped to produce better results.
        # In that case the implementation is based on crowsonkb's PLMS sampler implementation:
        # https://github.com/CompVis/latent-diffusion/pull/51
        self.prk_timesteps = np.array([])
        # The second-to-last timestep appears twice in the schedule.
        ascending = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])
        self.plms_timesteps = ascending[::-1].copy()
    else:
        half_stride = train_steps // num_inference_steps // 2
        # The last `pndm_order` timesteps, each paired with a copy shifted by half a stride.
        warmup = np.array(self._timesteps[-self.pndm_order :]).repeat(2)
        warmup = warmup + np.tile(np.array([0, half_stride]), self.pndm_order)
        self.prk_timesteps = warmup[:-1].repeat(2)[1:-1][::-1].copy()
        # copy to avoid negative strides, which torch.from_numpy does not support
        self.plms_timesteps = self._timesteps[:-3][::-1].copy()

    schedule = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
    self.timesteps = torch.from_numpy(schedule).to(device)

    # reset running values
    self.ets = []
    self.counter = 0
    self.cur_model_output = 0
Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) else: return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) def step_prk( self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with the Runge-Kutta method. It performs four forward passes to approximate the solution to the differential equation. Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. 
def step_plms(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
    the linear multistep method, combining the current model output with up to three stored earlier outputs.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        `ValueError`: if `set_timesteps` has not been called, or if PRK steps are enabled and fewer than three
        model outputs have been stored.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    prk_required = not self.config.skip_prk_steps
    if prk_required and len(self.ets) < 3:
        raise ValueError(
            f"{self.__class__} can only be run AFTER scheduler has been run "
            "in 'prk' mode for at least 12 iterations "
            "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
            "for more information."
        )

    stride = self.config.num_train_timesteps // self.num_inference_steps

    if self.counter == 1:
        # Second call: nothing is stored; the step is taken from one stride above `timestep` down to it.
        prev_timestep = timestep
        timestep = timestep + stride
    else:
        prev_timestep = timestep - stride
        # Keep the three most recent stored outputs and add the new one.
        self.ets = self.ets[-3:]
        self.ets.append(model_output)

    history = self.ets
    depth = len(history)

    if depth == 1 and self.counter == 0:
        # First call: the raw model output is used; remember the sample for the next call.
        self.cur_sample = sample
    elif depth == 1 and self.counter == 1:
        # Second call: average with the stored output and redo the step from the remembered sample.
        model_output = (model_output + history[-1]) / 2
        sample = self.cur_sample
        self.cur_sample = None
    elif depth == 2:
        model_output = (3 * history[-1] - history[-2]) / 2
    elif depth == 3:
        model_output = (23 * history[-1] - 16 * history[-2] + 5 * history[-3]) / 12
    else:
        model_output = (1 / 24) * (55 * history[-1] - 59 * history[-2] + 37 * history[-3] - 9 * history[-4])

    prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
    self.counter += 1

    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)
Returns: `torch.Tensor`: A scaled input sample. """ return sample def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf # this function computes x_(t−δ) using the formula of (9) # Note that x_t needs to be added to both sides of the equation # Notation ( -> # alpha_prod_t -> α_t # alpha_prod_t_prev -> α_(t−δ) # beta_prod_t -> (1 - α_t) # beta_prod_t_prev -> (1 - α_(t−δ)) # sample -> x_t # model_output -> e_θ(x_t, t) # prev_sample -> x_(t−δ) alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev if self.config.prediction_type == "v_prediction": model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample elif self.config.prediction_type != "epsilon": raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`" ) # corresponds to (α_(t−δ) - α_t) divided by # denominator of x_t in formula (9) and plus 1 # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = # sqrt(α_(t−δ)) / sqrt(α_t)) sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) # corresponds to denominator of e_θ(x_t, t) in formula (9) model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( alpha_prod_t * beta_prod_t * alpha_prod_t_prev ) ** (0.5) # full formula (9) prev_sample = ( sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff ) return prev_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU 
to GPU data movement # for the subsequent add_noise calls self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: src/diffusers/schedulers/scheduling_pndm_flax.py ================================================ # Copyright 2024 Zhejiang University Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim from dataclasses import dataclass from typing import Optional, Tuple, Union import flax import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, ) @flax.struct.dataclass class PNDMSchedulerState: common: CommonSchedulerState final_alpha_cumprod: jnp.ndarray # setable values init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray num_inference_steps: Optional[int] = None prk_timesteps: Optional[jnp.ndarray] = None plms_timesteps: Optional[jnp.ndarray] = None # running values cur_model_output: Optional[jnp.ndarray] = None counter: Optional[jnp.int32] = None cur_sample: Optional[jnp.ndarray] = None ets: Optional[jnp.ndarray] = None @classmethod def create( cls, common: CommonSchedulerState, final_alpha_cumprod: jnp.ndarray, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, ): return cls( common=common, final_alpha_cumprod=final_alpha_cumprod, init_noise_sigma=init_noise_sigma, timesteps=timesteps, ) @dataclass class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput): state: PNDMSchedulerState class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. 
For more details, see the original paper: https://arxiv.org/abs/2202.09778 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`jnp.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms steps; defaults to `False`. set_alpha_to_one (`bool`, default `False`): each diffusion step uses the value of alphas product at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the value of alpha at step 0. steps_offset (`int`, default `0`): An offset added to the inference steps, as required by some model families. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): the `dtype` used for params and computation. 
""" _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] dtype: jnp.dtype pndm_order: int @property def has_state(self): return True @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, skip_prk_steps: bool = False, set_alpha_to_one: bool = False, steps_offset: int = 0, prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): self.dtype = dtype # For now we only support F-PNDM, i.e. the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf # mainly at formula (9), (12), (13) and the Algorithm 2. self.pndm_order = 4 def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState: if common is None: common = CommonSchedulerState.create(self) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. final_alpha_cumprod = ( jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0] ) # standard deviation of the initial noise distribution init_noise_sigma = jnp.array(1.0, dtype=self.dtype) timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] return PNDMSchedulerState.create( common=common, final_alpha_cumprod=final_alpha_cumprod, init_noise_sigma=init_noise_sigma, timesteps=timesteps, ) def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. 
num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. shape (`Tuple`): the shape of the samples to be generated. """ step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # rounding to avoid issues when num_inference_step is power of 3 _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + self.config.steps_offset if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 prk_timesteps = jnp.array([], dtype=jnp.int32) plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1] else: prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile( jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32), self.pndm_order, ) prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1] plms_timesteps = _timesteps[:-3][::-1] timesteps = jnp.concatenate([prk_timesteps, plms_timesteps]) # initial running values cur_model_output = jnp.zeros(shape, dtype=self.dtype) counter = jnp.int32(0) cur_sample = jnp.zeros(shape, dtype=self.dtype) ets = jnp.zeros((4,) + shape, dtype=self.dtype) return state.replace( timesteps=timesteps, num_inference_steps=num_inference_steps, prk_timesteps=prk_timesteps, plms_timesteps=plms_timesteps, cur_model_output=cur_model_output, counter=counter, cur_sample=cur_sample, ets=ets, ) def scale_model_input( self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None ) -> jnp.ndarray: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. 
def step(
    self,
    state: PNDMSchedulerState,
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
    return_dict: bool = True,
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
    """
    Advance the diffusion process by one scheduler step.

    Dispatches to `step_plms()` alone when `skip_prk_steps` is configured. Otherwise both `step_prk()` and
    `step_plms()` are evaluated and their results merged according to `state.counter`.

    Args:
        state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`jnp.ndarray`): current instance of sample being created by diffusion process.
        return_dict (`bool`): when `False`, return a plain tuple instead of `FlaxPNDMSchedulerOutput`.

    Returns:
        [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise
        the tuple `(prev_sample, state)`.

    Raises:
        ValueError: If `set_timesteps` has not been run on `state`.
    """
    if state.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    step_args = (state, model_output, timestep, sample)

    if self.config.skip_prk_steps:
        prev_sample, next_state = self.step_plms(*step_args)
    else:
        # Both candidates are computed and merged with `select` instead of a Python branch on the counter
        # (presumably so the step stays traceable under jit).
        sample_from_prk, state_from_prk = self.step_prk(*step_args)
        sample_from_plms, state_from_plms = self.step_plms(*step_args)

        in_prk_phase = state.counter < len(state.prk_timesteps)

        def choose(prk_value, plms_value):
            return jax.lax.select(in_prk_phase, prk_value, plms_value)

        prev_sample = choose(sample_from_prk, sample_from_plms)

        running_fields = ("cur_model_output", "ets", "cur_sample", "counter")
        next_state = state.replace(
            **{
                field: choose(getattr(state_from_prk, field), getattr(state_from_plms, field))
                for field in running_fields
            }
        )

    if return_dict:
        return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=next_state)

    return (prev_sample, next_state)
""" if state.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) diff_to_prev = jnp.where( state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2 ) prev_timestep = timestep - diff_to_prev timestep = state.prk_timesteps[state.counter // 4 * 4] model_output = jax.lax.select( (state.counter % 4) != 3, model_output, # remainder 0, 1, 2 state.cur_model_output + 1 / 6 * model_output, # remainder 3 ) state = state.replace( cur_model_output=jax.lax.select_n( state.counter % 4, state.cur_model_output + 1 / 6 * model_output, # remainder 0 state.cur_model_output + 1 / 3 * model_output, # remainder 1 state.cur_model_output + 1 / 3 * model_output, # remainder 2 jnp.zeros_like(state.cur_model_output), # remainder 3 ), ets=jax.lax.select( (state.counter % 4) == 0, state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # remainder 0 state.ets, # remainder 1, 2, 3 ), cur_sample=jax.lax.select( (state.counter % 4) == 0, sample, # remainder 0 state.cur_sample, # remainder 1, 2, 3 ), ) cur_sample = state.cur_sample prev_sample = self._get_prev_sample(state, cur_sample, timestep, prev_timestep, model_output) state = state.replace(counter=state.counter + 1) return (prev_sample, state) def step_plms( self, state: PNDMSchedulerState, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: """ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple times to approximate the solution. Args: state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. 
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class Returns: [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if state.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) # NOTE: There is no way to check in the jitted runtime if the prk mode was ran before prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0) # Reference: # if state.counter != 1: # state.ets.append(model_output) # else: # prev_timestep = timestep # timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep) timestep = jnp.where( state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep ) # Reference: # if len(state.ets) == 1 and state.counter == 0: # model_output = model_output # state.cur_sample = sample # elif len(state.ets) == 1 and state.counter == 1: # model_output = (model_output + state.ets[-1]) / 2 # sample = state.cur_sample # state.cur_sample = None # elif len(state.ets) == 2: # model_output = (3 * state.ets[-1] - state.ets[-2]) / 2 # elif len(state.ets) == 3: # model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12 # else: # model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]) state = state.replace( ets=jax.lax.select( state.counter != 1, state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # counter != 1 state.ets, # counter 1 ), cur_sample=jax.lax.select( state.counter != 1, sample, # counter != 1 state.cur_sample, # counter 1 ), ) state = state.replace( cur_model_output=jax.lax.select_n( 
def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_timestep, model_output):
    """
    Compute x_(t−δ) from x_t with formula (9) of the PNDM paper
    (https://arxiv.org/pdf/2202.09778.pdf).

    Paper notation used below:
        alpha_t      -> α_t
        alpha_prev   -> α_(t−δ)
        sample       -> x_t
        eps          -> e_θ(x_t, t)

    Args:
        state: Scheduler state; supplies `common.alphas_cumprod` and `final_alpha_cumprod`.
        sample: The current sample x_t.
        timestep: Index into the cumulative alpha products for the current step.
        prev_timestep: Index for the target step. Where it is negative, `state.final_alpha_cumprod` is used.
        model_output: The model prediction, interpreted according to `self.config.prediction_type`.

    Returns:
        The sample at `prev_timestep`.

    Raises:
        ValueError: If `self.config.prediction_type` is neither `epsilon` nor `v_prediction`.
    """
    cumprod = state.common.alphas_cumprod
    alpha_t = cumprod[timestep]
    # `where` rather than a Python conditional, so the choice also works on array-valued timesteps.
    alpha_prev = jnp.where(prev_timestep >= 0, cumprod[prev_timestep], state.final_alpha_cumprod)
    one_minus_alpha_t = 1 - alpha_t
    one_minus_alpha_prev = 1 - alpha_prev

    prediction_type = self.config.prediction_type
    if prediction_type == "v_prediction":
        # Map the v-prediction onto the e_θ term that formula (9) is written in.
        eps = (alpha_t**0.5) * model_output + (one_minus_alpha_t**0.5) * sample
    elif prediction_type == "epsilon":
        eps = model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
        )

    # Coefficient of x_t: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1,
    # which simplifies to sqrt(α_(t−δ)) / sqrt(α_t).
    sample_coeff = (alpha_prev / alpha_t) ** (0.5)

    # Denominator of the e_θ(x_t, t) term in formula (9), built from its two summands.
    denom_first = alpha_t * one_minus_alpha_prev ** (0.5)
    denom_second = (alpha_t * one_minus_alpha_t * alpha_prev) ** (0.5)
    eps_denom = denom_first + denom_second

    # Full formula (9).
    return sample_coeff * sample - (alpha_prev - alpha_t) * eps / eps_denom
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize an `alpha_bar(t)` curve, the cumulative product of (1 - beta) over t in [0, 1], into a beta
    schedule.

    For step `i` the beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor with the betas used by the scheduler to step the model outputs

    Raises:
        ValueError: If `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    known_transforms = (
        ("cosine", lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2),
        ("exp", lambda t: math.exp(t * -12.0)),
    )
    alpha_bar_fn = next((fn for name, fn in known_transforms if alpha_transform_type == name), None)
    if alpha_bar_fn is None:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    steps = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((i + 1) / steps) / alpha_bar_fn(i / steps), max_beta)
        for i in range(steps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. beta_start (`float`, defaults to 0.0001): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. beta_schedule (`str`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, `squaredcos_cap_v2`, or `sigmoid`. eta (`float`): The weight of noise for added noise in diffusion step. If its value is between 0.0 and 1.0 it corresponds to the DDIM scheduler, and if its value is between -0.0 and 1.0 it corresponds to the DDPM scheduler. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. clip_sample (`bool`, defaults to `True`): Clip the predicted sample between -1 and 1 for numerical stability. """ order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", eta: float = 0.0, trained_betas: Optional[np.ndarray] = None, clip_sample: bool = True, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
def set_timesteps(
    self,
    num_inference_steps: int,
    jump_length: int = 10,
    jump_n_sample: int = 10,
    device: Union[str, torch.device] = None,
):
    """
    Build the RePaint timestep schedule (to be run before inference) and store it on the scheduler.

    The schedule walks down from `num_inference_steps - 1` to 0. At every index that is a multiple of
    `jump_length` below `num_inference_steps - jump_length`, it jumps `jump_length` steps forward again, up to
    `jump_n_sample - 1` times, before continuing downwards.

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. Values larger
            than `num_train_timesteps` are clamped to it.
        jump_length (`int`, defaults to 10):
            The number of steps taken forward in time before going backward in time for a single jump (“j” in
            RePaint paper). Take a look at Figure 9 and 10 in the paper.
        jump_n_sample (`int`, defaults to 10):
            The number of times to make a forward time jump for a given chosen time sample. Take a look at
            Figure 9 and 10 in the paper.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
    """
    train_steps = self.config.num_train_timesteps
    num_inference_steps = min(train_steps, num_inference_steps)
    self.num_inference_steps = num_inference_steps

    # How many forward jumps are still allowed at each jump index.
    remaining_jumps = {
        index: jump_n_sample - 1 for index in range(0, num_inference_steps - jump_length, jump_length)
    }

    schedule = []
    current = num_inference_steps
    while current >= 1:
        current -= 1
        schedule.append(current)

        if remaining_jumps.get(current, 0) > 0:
            remaining_jumps[current] -= 1
            # Re-noise path: step forward `jump_length` times, recording each index on the way.
            schedule.extend(range(current + 1, current + jump_length + 1))
            current += jump_length

    # Scale inference indices to training timesteps.
    stride = train_steps // self.num_inference_steps
    self.timesteps = torch.from_numpy(np.array(schedule) * stride).to(device)
def step(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    original_image: torch.Tensor,
    mask: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
) -> Union[RePaintSchedulerOutput, Tuple]:
    """
    Compute the sample for the previous timestep, keeping the known image region consistent with
    `original_image` and denoising the masked-out region from `model_output`.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        original_image (`torch.Tensor`):
            The original image to inpaint on.
        mask (`torch.Tensor`):
            The mask where a value of 0.0 indicates which part of the original image to inpaint.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~schedulers.scheduling_repaint.RePaintSchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_repaint.RePaintSchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_repaint.RePaintSchedulerOutput`] is returned,
            otherwise the tuple `(prev_sample, pred_original_sample)` is returned.
    """
    stride = self.config.num_train_timesteps // self.num_inference_steps
    prev_timestep = timestep - stride

    # 1. cumulative alpha products at the current and previous timestep
    alpha_bar_t = self.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_bar_prev = self.alphas_cumprod[prev_timestep]
    else:
        alpha_bar_prev = self.final_alpha_cumprod
    beta_bar_t = 1 - alpha_bar_t

    # 2. "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    x0_pred = (sample - beta_bar_t**0.5 * model_output) / alpha_bar_t**0.5

    # 3. optionally clip "predicted x_0"
    if self.config.clip_sample:
        x0_pred = torch.clamp(x0_pred, -1, 1)

    # RePaint Algorithm 1 is followed to get x_{t-1}, but formula (7) of the algorithm (DDPM sampling) is
    # substituted with formula (12) of the DDIM paper, which matches DDPM for eta = 1.0.
    # The same noise tensor is reused for the unknown (7.) and known (8.) parts.

    # 5. draw the noise; it is drawn even when no variance term ends up being added
    noise = randn_tensor(
        model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
    )
    std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
    noise_term = std_dev_t * noise if (timestep > 0 and self.eta > 0) else 0

    # 6. "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    direction_to_xt = (1 - alpha_bar_prev - std_dev_t**2) ** 0.5 * model_output

    # 7. x_{t-1} for the region to inpaint, formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    unknown_region = alpha_bar_prev**0.5 * x0_pred + direction_to_xt + noise_term

    # 8. noised original image for the known region, Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
    known_region = (alpha_bar_prev**0.5) * original_image + ((1 - alpha_bar_prev) ** 0.5) * noise

    # 9. blend both regions with the mask, Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
    blended = mask * known_region + (1.0 - mask) * unknown_region

    if return_dict:
        return RePaintSchedulerOutput(prev_sample=blended, pred_original_sample=x0_pred)

    return (
        blended,
        x0_pred,
    )
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Build a beta schedule by discretizing an `alpha_bar(t)` function defined on t in [0, 1].

    `alpha_bar(t)` is the cumulative product of (1 - beta) up to time `t`. Each beta is
    `1 - alpha_bar(t_next) / alpha_bar(t_current)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`):
            Number of betas to produce.
        max_beta (`float`):
            Upper bound applied to every beta; keep it below 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`):
            Which `alpha_bar` function to discretize: `cosine` or `exp`.

    Returns:
        `torch.Tensor`: A float32 tensor holding `num_diffusion_timesteps` betas.

    Raises:
        ValueError: If `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    schedule = [
        min(1 - alpha_bar((step + 1) / total) / alpha_bar(step / total), max_beta) for step in range(total)
    ]
    return torch.tensor(schedule, dtype=torch.float32)
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        tau_func (`Callable`, *optional*):
            Function that maps a timestep to the amount of stochasticity used at that step. Default in init is
            `lambda t: 1 if t >= 200 and t <= 800 else 0`. SA-Solver will sample from vanilla diffusion ODE if
            tau_func is set to `lambda t: 0`. SA-Solver will sample from vanilla diffusion SDE if tau_func is set to
            `lambda t: 1`. For more details, please check https://arxiv.org/abs/2309.05019
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        algorithm_type (`str`, defaults to `data_prediction`):
            Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use
            `data_prediction` with `predictor_order=2` and `corrector_order=2` for guided sampling like in Stable
            Diffusion.
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps. Default = True.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        lambda_min_clipped (`float`, defaults to `-inf`):
            Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
            cosine (`squaredcos_cap_v2`) noise schedule.
variance_type (`str`, *optional*): Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output contains the predicted Gaussian variance. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, predictor_order: int = 2, corrector_order: int = 2, prediction_type: str = "epsilon", tau_func: Optional[Callable] = None, thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "data_prediction", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
self.betas = ( torch.linspace( beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32, ) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # Currently we only support VP-type noise schedule self.alpha_t = torch.sqrt(self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 if algorithm_type not in ["data_prediction", "noise_prediction"]: raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.timestep_list = [None] * max(predictor_order, corrector_order - 1) self.model_outputs = [None] * max(predictor_order, corrector_order - 1) if tau_func is None: self.tau_func = lambda t: 1 if t >= 200 and t <= 800 else 0 else: self.tau_func = tau_func self.predict_x0 = algorithm_type == "data_prediction" self.lower_order_nums = 0 self.last_sample = None self._step_index = None self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property def step_index(self): """ The index counter for current timestep. It will increase 1 after each scheduler step. """ return self._step_index @property def begin_index(self): """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. 
    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Besides `self.timesteps` and `self.sigmas`, this resets the multistep state: `self.model_outputs`,
        `self.lower_order_nums`, `self.last_sample`, `self._step_index` and `self._begin_index`.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

        Raises:
            ValueError: If `self.config.timestep_spacing` is not `linspace`, `leading` or `trailing`.
        """
        # Clipping the minimum of all lambda(t) for numerical stability.
        # This is critical for cosine (squaredcos_cap_v2) noise schedule.
        # `lambda_t` is flipped before the search; the resulting index is subtracted from
        # `num_train_timesteps` to obtain the largest usable training timestep (exclusive).
        clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
        last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            # num_inference_steps + 1 evenly spaced points, reversed to descending order, with the final
            # (smallest) point dropped so that exactly num_inference_steps timesteps remain.
            timesteps = (
                np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = last_timestep // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Sigma for every training timestep: sqrt((1 - alpha_bar) / alpha_bar).
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.use_karras_sigmas:
            # Karras path: the sigmas are replaced by the Karras schedule and the timesteps are
            # re-derived from those sigmas, overriding the spacing computed above.
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            # The final sigma is repeated so that `sigmas` has one more entry than `timesteps`.
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        else:
            # Default path: read the training sigmas at the selected timesteps and append the sigma of
            # training timestep 0 as the terminal entry.
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)
        # Fresh, empty history of converted model outputs for the multistep predictor/corrector.
        self.model_outputs = [
            None,
        ] * max(self.config.predictor_order, self.config.corrector_order - 1)
        self.lower_order_nums = 0
        self.last_sample = None

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) abs_sample = sample.abs() # "a certain percentile absolute pixel value" s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) s = torch.clamp( s, min=1, max=self.config.sample_max_value ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" sample = sample.reshape(batch_size, channels, *remaining_dims) sample = sample.to(dtype) return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) # get distribution dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low = log_sigmas[low_idx] high = log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) w = np.clip(w, 0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): alpha_t = 1 / ((sigma**2 + 1) ** 0.5) sigma_t = sigma * alpha_t return alpha_t, sigma_t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: 
torch.Tensor, num_inference_steps) -> torch.Tensor: """Constructs the noise schedule of Karras et al. (2022).""" # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers if hasattr(self.config, "sigma_min"): sigma_min = self.config.sigma_min else: sigma_min = None if hasattr(self.config, "sigma_max"): sigma_max = self.config.sigma_max else: sigma_max = None sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, num_inference_steps) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas def convert_model_output( self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, **kwargs, ) -> torch.Tensor: """ Convert the model output to the corresponding type the data_prediction/noise_prediction algorithm needs. Noise_prediction is designed to discretize an integral of the noise prediction model, and data_prediction is designed to discretize an integral of the data prediction model. The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction for both noise prediction and data prediction models. Args: model_output (`torch.Tensor`): The direct output from the learned diffusion model. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. Returns: `torch.Tensor`: The converted model output. 
""" timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) if sample is None: if len(args) > 1: sample = args[1] else: raise ValueError("missing `sample` as a required keyward argument") if timestep is not None: deprecate( "timesteps", "1.0.0", "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) sigma = self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) # SA-Solver_data_prediction needs to solve an integral of the data prediction model. if self.config.algorithm_type in ["data_prediction"]: if self.config.prediction_type == "epsilon": # SA-Solver only needs the "mean" output. if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction` for the SASolverScheduler." ) if self.config.thresholding: x0_pred = self._threshold_sample(x0_pred) return x0_pred # SA-Solver_noise_prediction needs to solve an integral of the noise prediction model. elif self.config.algorithm_type in ["noise_prediction"]: if self.config.prediction_type == "epsilon": # SA-Solver only needs the "mean" output. 
if self.config.variance_type in ["learned", "learned_range"]: epsilon = model_output[:, :3] else: epsilon = model_output elif self.config.prediction_type == "sample": epsilon = (sample - alpha_t * model_output) / sigma_t elif self.config.prediction_type == "v_prediction": epsilon = alpha_t * model_output + sigma_t * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction` for the SASolverScheduler." ) if self.config.thresholding: alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * epsilon) / alpha_t x0_pred = self._threshold_sample(x0_pred) epsilon = (sample - alpha_t * x0_pred) / sigma_t return epsilon def get_coefficients_exponential_negative(self, order, interval_start, interval_end): """ Calculate the integral of exp(-x) * x^order dx from interval_start to interval_end """ assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" if order == 0: return torch.exp(-interval_end) * (torch.exp(interval_end - interval_start) - 1) elif order == 1: return torch.exp(-interval_end) * ( (interval_start + 1) * torch.exp(interval_end - interval_start) - (interval_end + 1) ) elif order == 2: return torch.exp(-interval_end) * ( (interval_start**2 + 2 * interval_start + 2) * torch.exp(interval_end - interval_start) - (interval_end**2 + 2 * interval_end + 2) ) elif order == 3: return torch.exp(-interval_end) * ( (interval_start**3 + 3 * interval_start**2 + 6 * interval_start + 6) * torch.exp(interval_end - interval_start) - (interval_end**3 + 3 * interval_end**2 + 6 * interval_end + 6) ) def get_coefficients_exponential_positive(self, order, interval_start, interval_end, tau): """ Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end """ assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3" # after change of variable(cov) interval_end_cov = (1 + tau**2) * interval_end 
interval_start_cov = (1 + tau**2) * interval_start if order == 0: return ( torch.exp(interval_end_cov) * (1 - torch.exp(-(interval_end_cov - interval_start_cov))) / (1 + tau**2) ) elif order == 1: return ( torch.exp(interval_end_cov) * ( (interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp(-(interval_end_cov - interval_start_cov)) ) / ((1 + tau**2) ** 2) ) elif order == 2: return ( torch.exp(interval_end_cov) * ( (interval_end_cov**2 - 2 * interval_end_cov + 2) - (interval_start_cov**2 - 2 * interval_start_cov + 2) * torch.exp(-(interval_end_cov - interval_start_cov)) ) / ((1 + tau**2) ** 3) ) elif order == 3: return ( torch.exp(interval_end_cov) * ( (interval_end_cov**3 - 3 * interval_end_cov**2 + 6 * interval_end_cov - 6) - (interval_start_cov**3 - 3 * interval_start_cov**2 + 6 * interval_start_cov - 6) * torch.exp(-(interval_end_cov - interval_start_cov)) ) / ((1 + tau**2) ** 4) ) def lagrange_polynomial_coefficient(self, order, lambda_list): """ Calculate the coefficient of lagrange polynomial """ assert order in [0, 1, 2, 3] assert order == len(lambda_list) - 1 if order == 0: return [[1]] elif order == 1: return [ [ 1 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1]), ], [ 1 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0]), ], ] elif order == 2: denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) return [ [ 1 / denominator1, (-lambda_list[1] - lambda_list[2]) / denominator1, lambda_list[1] * lambda_list[2] / denominator1, ], [ 1 / denominator2, (-lambda_list[0] - lambda_list[2]) / denominator2, lambda_list[0] * lambda_list[2] / denominator2, ], [ 1 / denominator3, (-lambda_list[0] - lambda_list[1]) / denominator3, lambda_list[0] * lambda_list[1] / denominator3, ], ] elif 
order == 3: denominator1 = ( (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * (lambda_list[0] - lambda_list[3]) ) denominator2 = ( (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * (lambda_list[1] - lambda_list[3]) ) denominator3 = ( (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * (lambda_list[2] - lambda_list[3]) ) denominator4 = ( (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * (lambda_list[3] - lambda_list[2]) ) return [ [ 1 / denominator1, (-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1, ( lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[3] ) / denominator1, (-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1, ], [ 1 / denominator2, (-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2, ( lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[3] ) / denominator2, (-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2, ], [ 1 / denominator3, (-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3, ( lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[3] ) / denominator3, (-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3, ], [ 1 / denominator4, (-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4, ( lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[2] ) / denominator4, (-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4, ], ] def get_coefficients_fn(self, order, interval_start, interval_end, lambda_list, tau): assert order in [1, 2, 3, 4] assert order == len(lambda_list), "the length of lambda list must be equal to the order" coefficients = [] lagrange_coefficient = self.lagrange_polynomial_coefficient(order - 1, lambda_list) for i in range(order): coefficient = 0 for j in 
    def stochastic_adams_bashforth_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor,
        noise: torch.Tensor,
        order: int,
        tau: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the SA-Predictor.

        The update combines three terms: the current sample scaled by a decay factor, a weighted sum of the
        `order` most recent entries of `self.model_outputs`, and `noise` scaled by a `tau`-dependent factor.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model at the current timestep. Not read directly
                here; the update uses the history stored in `self.model_outputs`.
            prev_timestep (`int`):
                Deprecated and ignored; the step is located through `self.step_index`.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            noise (`torch.Tensor`):
                Noise tensor injected by this step. It is combined additively with `sample`, so presumably it has
                the same shape (TODO confirm against caller).
            order (`int`):
                The order of SA-Predictor at this timestep.
            tau:
                Stochasticity value for this step. With `tau == 0` the noise term below vanishes.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
        # Legacy positional fallbacks: each required value may alternatively arrive through *args.
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError(" missing `sample` as a required keyward argument")
        if noise is None:
            if len(args) > 2:
                noise = args[2]
            else:
                raise ValueError(" missing `noise` as a required keyward argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError(" missing `order` as a required keyward argument")
        if tau is None:
            if len(args) > 4:
                tau = args[4]
            else:
                raise ValueError(" missing `tau` as a required keyward argument")
        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        model_output_list = self.model_outputs
        # `t` is the step being predicted (next index in the schedule), `s0` is the current step.
        sigma_t, sigma_s0 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
        )
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        # lambda = log(alpha) - log(sigma)
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        gradient_part = torch.zeros_like(sample)
        # Step size in lambda space.
        h = lambda_t - lambda_s0
        lambda_list = []

        # Lambdas of the current step and the (order - 1) steps before it: the interpolation nodes.
        for i in range(order):
            si = self.step_index - i
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            lambda_list.append(lambda_si)

        # One integration weight per stored model output, over the interval [lambda_s0, lambda_t].
        gradient_coefficients = self.get_coefficients_fn(order, lambda_s0, lambda_t, lambda_list, tau)

        x = sample

        if self.predict_x0:
            if (
                order == 2
            ):  ## if order = 2 we do a modification that does not influence the convergence order similar to unipc. Note: This is used only for few steps sampling.
                # The added term is O(h^3). Empirically we find it will slightly improve the image quality.
                # ODE case
                # gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
                # gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
                # Lambda of the step before the current one; the correction is divided by the lambda gap
                # between the two most recent steps.
                temp_sigma = self.sigmas[self.step_index - 1]
                temp_alpha_s, temp_sigma_s = self._sigma_to_alpha_sigma_t(temp_sigma)
                temp_lambda_s = torch.log(temp_alpha_s) - torch.log(temp_sigma_s)
                # The same quantity is added to the first weight and subtracted from the second.
                gradient_coefficients[0] += (
                    1.0
                    * torch.exp((1 + tau**2) * lambda_t)
                    * (h**2 / 2 - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) / ((1 + tau**2) ** 2))
                    / (lambda_s0 - temp_lambda_s)
                )
                gradient_coefficients[1] -= (
                    1.0
                    * torch.exp((1 + tau**2) * lambda_t)
                    * (h**2 / 2 - (h * (1 + tau**2) - 1 + torch.exp((1 + tau**2) * (-h))) / ((1 + tau**2) ** 2))
                    / (lambda_s0 - temp_lambda_s)
                )

        # Weighted sum over the stored outputs; index -(i + 1) walks from the newest entry backwards.
        for i in range(order):
            if self.predict_x0:
                gradient_part += (
                    (1 + tau**2)
                    * sigma_t
                    * torch.exp(-(tau**2) * lambda_t)
                    * gradient_coefficients[i]
                    * model_output_list[-(i + 1)]
                )
            else:
                gradient_part += -(1 + tau**2) * alpha_t * gradient_coefficients[i] * model_output_list[-(i + 1)]

        # Stochastic term; both forms evaluate to zero when tau == 0.
        if self.predict_x0:
            noise_part = sigma_t * torch.sqrt(1 - torch.exp(-2 * tau**2 * h)) * noise
        else:
            noise_part = tau * sigma_t * torch.sqrt(torch.exp(2 * h) - 1) * noise

        if self.predict_x0:
            x_t = torch.exp(-(tau**2) * h) * (sigma_t / sigma_s0) * x + gradient_part + noise_part
        else:
            x_t = (alpha_t / alpha_s0) * x + gradient_part + noise_part

        # Return in the dtype of the incoming sample.
        x_t = x_t.to(x.dtype)
        return x_t
last_sample (`torch.Tensor`): The generated sample before the last predictor `x_{t-1}`. this_sample (`torch.Tensor`): The generated sample after the last predictor `x_{t}`. order (`int`): The order of SA-Corrector at this step. Returns: `torch.Tensor`: The corrected sample tensor at the current timestep. """ this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) if last_sample is None: if len(args) > 1: last_sample = args[1] else: raise ValueError(" missing`last_sample` as a required keyward argument") if last_noise is None: if len(args) > 2: last_noise = args[2] else: raise ValueError(" missing`last_noise` as a required keyward argument") if this_sample is None: if len(args) > 3: this_sample = args[3] else: raise ValueError(" missing`this_sample` as a required keyward argument") if order is None: if len(args) > 4: order = args[4] else: raise ValueError(" missing`order` as a required keyward argument") if tau is None: if len(args) > 5: tau = args[5] else: raise ValueError(" missing`tau` as a required keyward argument") if this_timestep is not None: deprecate( "this_timestep", "1.0.0", "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) model_output_list = self.model_outputs sigma_t, sigma_s0 = ( self.sigmas[self.step_index], self.sigmas[self.step_index - 1], ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) gradient_part = torch.zeros_like(this_sample) h = lambda_t - lambda_s0 lambda_list = [] for i in range(order): si = self.step_index - i alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) lambda_si = torch.log(alpha_si) - torch.log(sigma_si) lambda_list.append(lambda_si) model_prev_list = model_output_list + [this_model_output] gradient_coefficients = 
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Return the position of `timestep` inside a timestep schedule.

    Args:
        timestep: The timestep value to look up (compared element-wise with `==` against the schedule).
        schedule_timesteps (*optional*): The schedule to search. Falls back to `self.timesteps` when `None`.

    Returns:
        `int`: The index chosen for `timestep`:
            - no match: the last index of `self.timesteps`;
            - more than one match: the index of the **second** match;
            - exactly one match: the index of that match.
    """
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (schedule == timestep).nonzero()
    num_matches = len(matches)

    if num_matches == 0:
        # Unknown timestep: fall back to the final position of the scheduler's own timesteps.
        return len(self.timesteps) - 1

    # The sigma index that is taken for the **very** first `step` is always the second index
    # (or the only index if there is just one match). This way we can ensure we don't
    # accidentally skip a sigma in case we start in the middle of the denoising schedule
    # (e.g. for image-to-image).
    chosen = 1 if num_matches > 1 else 0
    return matches[chosen].item()
""" if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if self.step_index is None: self._init_step_index(timestep) use_corrector = self.step_index > 0 and self.last_sample is not None model_output_convert = self.convert_model_output(model_output, sample=sample) if use_corrector: current_tau = self.tau_func(self.timestep_list[-1]) sample = self.stochastic_adams_moulton_update( this_model_output=model_output_convert, last_sample=self.last_sample, last_noise=self.last_noise, this_sample=sample, order=self.this_corrector_order, tau=current_tau, ) for i in range(max(self.config.predictor_order, self.config.corrector_order - 1) - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.timestep_list[i] = self.timestep_list[i + 1] self.model_outputs[-1] = model_output_convert self.timestep_list[-1] = timestep noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype, ) if self.config.lower_order_final: this_predictor_order = min(self.config.predictor_order, len(self.timesteps) - self.step_index) this_corrector_order = min(self.config.corrector_order, len(self.timesteps) - self.step_index + 1) else: this_predictor_order = self.config.predictor_order this_corrector_order = self.config.corrector_order self.this_predictor_order = min(this_predictor_order, self.lower_order_nums + 1) # warmup for multistep self.this_corrector_order = min(this_corrector_order, self.lower_order_nums + 2) # warmup for multistep assert self.this_predictor_order > 0 assert self.this_corrector_order > 0 self.last_sample = sample self.last_noise = noise current_tau = self.tau_func(self.timestep_list[-1]) prev_sample = self.stochastic_adams_bashforth_update( model_output=model_output_convert, sample=sample, noise=noise, order=self.this_predictor_order, tau=current_tau, ) if self.lower_order_nums < max(self.config.predictor_order, 
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Mix `noise` into `original_samples` using the cumulative alpha products at `timesteps`.

    The result is `sqrt(alphas_cumprod[t]) * original_samples + sqrt(1 - alphas_cumprod[t]) * noise`, where the
    per-timestep coefficients are flattened and given trailing singleton axes until they have as many dimensions as
    `original_samples`.

    Args:
        original_samples (`torch.Tensor`): The samples to which noise is added.
        noise (`torch.Tensor`): The noise to mix in.
        timesteps (`torch.IntTensor`): Indices into `self.alphas_cumprod`.

    Returns:
        `torch.Tensor`: The noisy samples.
    """
    # Keep `self.alphas_cumprod` on the samples' device so that subsequent calls do not repeat
    # the CPU to GPU transfer; the dtype cast is local to this call only.
    self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
    cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
    step_ids = timesteps.to(original_samples.device)

    # After flattening, the coefficient is 1-D; append singleton axes up to the sample rank.
    trailing_axes = max(len(original_samples.shape) - 1, 0)
    expand = (slice(None),) + (None,) * trailing_axes

    signal_scale = (cumprod[step_ids] ** 0.5).flatten()[expand]
    noise_scale = ((1 - cumprod[step_ids]) ** 0.5).flatten()[expand]

    return signal_scale * original_samples + noise_scale * noise
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
    """
    `ScoreSdeVeScheduler` is a variance exploding stochastic differential equation (SDE) scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 2000):
            The number of diffusion steps to train the model.
        snr (`float`, defaults to 0.15):
            A coefficient weighting the step from the `model_output` sample (from the network) to the random noise.
        sigma_min (`float`, defaults to 0.01):
            The initial noise scale for the sigma sequence in the sampling procedure. The minimum sigma should mirror
            the distribution of the data.
        sigma_max (`float`, defaults to 1348.0):
            The maximum value used for the range of continuous timesteps passed into the model.
        sampling_eps (`float`, defaults to 1e-5):
            The end value of sampling where timesteps decrease progressively from 1 to epsilon.
        correct_steps (`int`, defaults to 1):
            The number of correction steps performed on a produced sample.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 2000,
        snr: float = 0.15,
        sigma_min: float = 0.01,
        sigma_max: float = 1348.0,
        sampling_eps: float = 1e-5,
        correct_steps: int = 1,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        # setable values
        self.timesteps = None

        # `set_sigmas` also populates `self.timesteps` because it is still `None` at this point.
        self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample (this scheduler returns `sample` unchanged).
        """
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        sampling_eps: Optional[float] = None,
        device: Union[str, torch.device] = None,
    ):
        """
        Sets the continuous timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            sampling_eps (`float`, *optional*):
                The final timestep value (overrides value given during scheduler instantiation).
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)

    def set_sigmas(
        self,
        num_inference_steps: int,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        sampling_eps: Optional[float] = None,
    ):
        """
        Sets the noise scales used for the diffusion chain (to be run before inference). The sigmas control the weight
        of the `drift` and `diffusion` components of the sample update.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            sigma_min (`float`, optional):
                The initial noise scale value (overrides value given during scheduler instantiation).
            sigma_max (`float`, optional):
                The final noise scale value (overrides value given during scheduler instantiation).
            sampling_eps (`float`, optional):
                The final timestep value (overrides value given during scheduler instantiation). Only used when
                `self.timesteps` has not been set yet.
        """
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        if self.timesteps is None:
            self.set_timesteps(num_inference_steps, sampling_eps)

        # NOTE: a previous revision first assigned `self.sigmas` using the exponent
        # `self.timesteps / sampling_eps`; that value was overwritten below before anything
        # could read it, so the dead computation has been dropped.
        self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
        self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])

    def get_adjacent_sigma(self, timesteps, t):
        """
        Return `self.discrete_sigmas[timesteps - 1]`, with zeros substituted wherever `timesteps == 0`.

        Args:
            timesteps: Integer indices into `self.discrete_sigmas`.
            t: Tensor whose shape/dtype is used for the zero fill value.
        """
        return torch.where(
            timesteps == 0,
            torch.zeros_like(t.to(timesteps.device)),
            self.discrete_sigmas[timesteps - 1].to(timesteps.device),
        )

    def step_pred(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SdeVeOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_sde_ve.SdeVeOutput`] is returned, otherwise a tuple
                `(prev_sample, prev_sample_mean)` is returned.

        Raises:
            ValueError: If `self.timesteps` has not been set.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # Repeat the timestep once per batch element, then map it to an index into `discrete_sigmas`.
        timestep = timestep * torch.ones(
            sample.shape[0], device=sample.device
        )  # torch.repeat_interleave(timestep, sample.shape[0])
        timesteps = (timestep * (len(self.timesteps) - 1)).long()

        # mps requires indices to be in the same device, so we use cpu as is the default with cuda
        timesteps = timesteps.to(self.discrete_sigmas.device)

        sigma = self.discrete_sigmas[timesteps].to(sample.device)
        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
        drift = torch.zeros_like(sample)
        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
        # also equation 47 shows the analog from SDE models to ancestral sampling methods
        diffusion = diffusion.flatten()
        while len(diffusion.shape) < len(sample.shape):
            diffusion = diffusion.unsqueeze(-1)
        drift = drift - diffusion**2 * model_output

        # equation 6: sample noise for the diffusion term of
        noise = randn_tensor(
            sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype
        )
        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
        # TODO is the variable diffusion the correct scaling term for the noise?
        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

        if not return_dict:
            return (prev_sample, prev_sample_mean)

        return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)

    def step_correct(
        self,
        model_output: torch.Tensor,
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Correct the predicted sample based on the `model_output` of the network. This is often run repeatedly after
        making the prediction for the previous timestep.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `self.timesteps` has not been set.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
        # sample noise for correction
        noise = randn_tensor(sample.shape, layout=sample.layout, generator=generator).to(sample.device)

        # compute step size from the model_output, the noise, and the snr
        grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
        # self.repeat_scalar(step_size, sample.shape[0])

        # compute corrected sample: model_output term and noise term
        step_size = step_size.flatten()
        while len(step_size.shape) < len(sample.shape):
            step_size = step_size.unsqueeze(-1)
        prev_sample_mean = sample + step_size * model_output
        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """
        Add noise scaled by `self.discrete_sigmas[timesteps]` to `original_samples`.

        Args:
            original_samples (`torch.Tensor`):
                The samples to perturb.
            noise (`torch.Tensor`):
                The noise to scale and add. When `None`, standard normal noise shaped like `original_samples` is drawn.
            timesteps (`torch.Tensor`):
                Indices into `self.discrete_sigmas`.

        Returns:
            `torch.Tensor`: `original_samples` plus the sigma-scaled noise.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        timesteps = timesteps.to(original_samples.device)
        sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps]

        # Broadcast one sigma per batch element over the remaining sample axes, whatever the
        # sample rank is (previously hard-coded for 4-D inputs; the 4-D result is unchanged).
        sigmas = sigmas.flatten()
        while len(sigmas.shape) < len(original_samples.shape):
            sigmas = sigmas.unsqueeze(-1)

        noise = noise * sigmas if noise is not None else torch.randn_like(original_samples) * sigmas
        noisy_samples = noise + original_samples
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    The variance exploding stochastic differential equation (SDE) scheduler.

    For more information, see the original paper: https://arxiv.org/abs/2011.13456

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        snr (`float`):
            coefficient weighting the step from the model_output sample (from the network) to the random noise.
        sigma_min (`float`):
            initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
            distribution of the data.
        sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
        sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
            epsilon.
        correct_steps (`int`): number of correction steps performed on a produced sample.
    """

    @property
    def has_state(self):
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 2000,
        snr: float = 0.15,
        sigma_min: float = 0.01,
        sigma_max: float = 1348.0,
        sampling_eps: float = 1e-5,
        correct_steps: int = 1,
    ):
        pass

    def create_state(self):
        """Create a fresh scheduler state with timesteps and sigmas initialised from the config values."""
        state = ScoreSdeVeSchedulerState.create()
        return self.set_sigmas(
            state,
            self.config.num_train_timesteps,
            self.config.sigma_min,
            self.config.sigma_max,
            self.config.sampling_eps,
        )

    def set_timesteps(
        self,
        state: ScoreSdeVeSchedulerState,
        num_inference_steps: int,
        shape: Tuple = (),
        sampling_eps: Optional[float] = None,
    ) -> ScoreSdeVeSchedulerState:
        """
        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`, optional): accepted for signature compatibility; not used by this method.
            sampling_eps (`float`, optional):
                final timestep value (overrides value given at Scheduler instantiation).

        Returns:
            `ScoreSdeVeSchedulerState`: a copy of `state` with `timesteps` replaced.
        """
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

        timesteps = jnp.linspace(1, sampling_eps, num_inference_steps)
        return state.replace(timesteps=timesteps)

    def set_sigmas(
        self,
        state: ScoreSdeVeSchedulerState,
        num_inference_steps: int,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        sampling_eps: Optional[float] = None,
    ) -> ScoreSdeVeSchedulerState:
        """
        Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.

        The sigmas control the weight of the `drift` and `diffusion` components of sample update.

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            sigma_min (`float`, optional):
                initial noise scale value (overrides value given at Scheduler instantiation).
            sigma_max (`float`, optional):
                final noise scale value (overrides value given at Scheduler instantiation).
            sampling_eps (`float`, optional):
                final timestep value (overrides value given at Scheduler instantiation). Only used when
                `state.timesteps` has not been set yet.

        Returns:
            `ScoreSdeVeSchedulerState`: a copy of `state` with `discrete_sigmas` and `sigmas` replaced (and
            `timesteps` populated if it was `None`).
        """
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        if state.timesteps is None:
            # Pass `sampling_eps` by keyword: the third positional parameter of `set_timesteps`
            # is `shape`, so a positional call would silently drop the override.
            state = self.set_timesteps(state, num_inference_steps, sampling_eps=sampling_eps)

        discrete_sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_min), jnp.log(sigma_max), num_inference_steps))
        sigmas = jnp.array([sigma_min * (sigma_max / sigma_min) ** t for t in state.timesteps])

        return state.replace(discrete_sigmas=discrete_sigmas, sigmas=sigmas)

    def get_adjacent_sigma(self, state, timesteps, t):
        """Return `state.discrete_sigmas[timesteps - 1]`, with zeros substituted wherever `timesteps == 0`."""
        return jnp.where(timesteps == 0, jnp.zeros_like(t), state.discrete_sigmas[timesteps - 1])

    def step_pred(
        self,
        state: ScoreSdeVeSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        key: jax.Array,
        return_dict: bool = True,
    ) -> Union[FlaxSdeVeOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            key (`jax.Array`): PRNG key used to draw the noise for the diffusion term.
            return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class

        Returns:
            [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`
            `(prev_sample, prev_sample_mean, state)`.

        Raises:
            ValueError: If `state.timesteps` has not been set.
        """
        if state.timesteps is None:
            raise ValueError(
                "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        timestep = timestep * jnp.ones(
            sample.shape[0],
        )
        # JAX arrays have no `.long()` (that is a torch method); cast to an integer dtype so the
        # result can be used as an index into `discrete_sigmas`.
        timesteps = (timestep * (len(state.timesteps) - 1)).astype(jnp.int32)

        sigma = state.discrete_sigmas[timesteps]
        adjacent_sigma = self.get_adjacent_sigma(state, timesteps, timestep)
        drift = jnp.zeros_like(sample)
        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
        # also equation 47 shows the analog from SDE models to ancestral sampling methods
        diffusion = diffusion.flatten()
        diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
        drift = drift - diffusion**2 * model_output

        # equation 6: sample noise for the diffusion term of
        # `random.split(key, num=1)` returns a batch holding one key; take that single key
        # because `random.normal` expects an unbatched key.
        key = random.split(key, num=1)[0]
        noise = random.normal(key=key, shape=sample.shape)
        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
        # TODO is the variable diffusion the correct scaling term for the noise?
        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

        if not return_dict:
            return (prev_sample, prev_sample_mean, state)

        return FlaxSdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean, state=state)

    def step_correct(
        self,
        state: ScoreSdeVeSchedulerState,
        model_output: jnp.ndarray,
        sample: jnp.ndarray,
        key: jax.Array,
        return_dict: bool = True,
    ) -> Union[FlaxSdeVeOutput, Tuple]:
        """
        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
        after making the prediction for the previous timestep.

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            key (`jax.Array`): PRNG key used to draw the correction noise.
            return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class

        Returns:
            [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`
            `(prev_sample, state)`.

        Raises:
            ValueError: If `state.timesteps` has not been set.
        """
        if state.timesteps is None:
            raise ValueError(
                "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
        # sample noise for correction
        # Take the single key out of the batch returned by `random.split` (see `step_pred`).
        key = random.split(key, num=1)[0]
        noise = random.normal(key=key, shape=sample.shape)

        # compute step size from the model_output, the noise, and the snr
        grad_norm = jnp.linalg.norm(model_output)
        noise_norm = jnp.linalg.norm(noise)
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        step_size = step_size * jnp.ones(sample.shape[0])

        # compute corrected sample: model_output term and noise term
        step_size = step_size.flatten()
        step_size = broadcast_to_shape_from_left(step_size, sample.shape)
        prev_sample_mean = sample + step_size * model_output
        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

        if not return_dict:
            return (prev_sample, state)

        return FlaxSdeVeOutput(prev_sample=prev_sample, state=state)

    def __len__(self):
        return self.config.num_train_timesteps
# See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion # and https://github.com/hojonathanho/diffusion import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..schedulers.scheduling_utils import SchedulerMixin from ..utils import BaseOutput, logging from ..utils.torch_utils import randn_tensor logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class TCDSchedulerOutput(BaseOutput): """ Output class for the scheduler's `step` function output. Args: prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_noised_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted noised sample `(x_{s})` based on the model output from the current timestep. """ prev_sample: torch.Tensor pred_noised_sample: Optional[torch.Tensor] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """
    Rescale a beta schedule so that the terminal signal-to-noise ratio is zero.

    Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # betas -> sqrt of the cumulative alpha product
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the endpoints before transforming the schedule.
    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift so the final entry becomes zero, then scale so the first entry regains its old value.
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    # Undo the sqrt, then undo the cumulative product to recover per-step alphas.
    alpha_bar = sqrt_alpha_bar**2
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])

    return 1 - step_alphas
clip_sample (`bool`, defaults to `True`): Clip the predicted sample for numerical stability. clip_sample_range (`float`, defaults to 1.0): The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. set_alpha_to_one (`bool`, defaults to `True`): Each diffusion step uses the alphas product value at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the alpha value at step 0. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. dynamic_thresholding_ratio (`float`, defaults to 0.995): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. timestep_scaling (`float`, defaults to 10.0): The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation error at the default of `10.0` is already pretty small). 
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "scaled_linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    original_inference_steps: int = 50,
    clip_sample: bool = False,
    clip_sample_range: float = 1.0,
    set_alpha_to_one: bool = True,
    steps_offset: int = 0,
    prediction_type: str = "epsilon",
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    timestep_spacing: str = "leading",
    timestep_scaling: float = 10.0,
    rescale_betas_zero_snr: bool = False,
):
    """
    Build the beta / alpha tables and reset the mutable sampling state.

    `trained_betas`, when given, takes precedence over `beta_schedule`. Arguments that are not read in this body are
    still part of the signature decorated with `register_to_config`.

    Raises:
        `NotImplementedError`: If `trained_betas` is `None` and `beta_schedule` is not one of `linear`,
            `scaled_linear` or `squaredcos_cap_v2`.
    """
    # Pick the beta schedule.
    if trained_betas is not None:
        schedule = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        schedule = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model: linear in sqrt(beta), then squared.
        sqrt_schedule = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        schedule = sqrt_schedule**2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        schedule = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    # Optionally rescale for zero terminal SNR.
    if rescale_betas_zero_snr:
        schedule = rescale_zero_terminal_snr(schedule)

    self.betas = schedule
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # The final denoising step has no "previous" cumulative alpha. `set_alpha_to_one` selects between
    # a fixed 1.0 and the cumulative alpha of training step 0 as the stand-in value.
    if set_alpha_to_one:
        self.final_alpha_cumprod = torch.tensor(1.0)
    else:
        self.final_alpha_cumprod = self.alphas_cumprod[0]

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # setable values
    self.num_inference_steps = None
    descending_steps = np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)
    self.timesteps = torch.from_numpy(descending_steps)
    self.custom_timesteps = False

    self._step_index = None
    self._begin_index = None
# Adapted from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
    """
    Initialise `self._step_index` for the first call to `step`.

    When a begin index has been set, it is used as-is. Otherwise the index is looked up from `timestep` via
    `index_for_timestep`; tensor timesteps are first moved to the device of `self.timesteps`.

    Args:
        timestep: The timestep passed to the first `step` call.
    """
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return

    lookup = timestep
    if isinstance(lookup, torch.Tensor):
        # The equality search inside `index_for_timestep` needs both tensors on the same device.
        lookup = lookup.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(lookup)
# Adapted from diffusers.schedulers.scheduling_ddim.DDIMScheduler._get_variance
def _get_variance(self, timestep, prev_timestep):
    """
    Compute the variance term `(1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)` for one step.

    `a_t` and `a_prev` are read from `self.alphas_cumprod`; a negative `prev_timestep` selects
    `self.final_alpha_cumprod` for `a_prev` instead.

    Args:
        timestep: Index of the current timestep into `self.alphas_cumprod`.
        prev_timestep: Index of the previous timestep, or a negative value once the schedule is exhausted.

    Returns:
        The variance for the transition from `timestep` to `prev_timestep`.
    """
    cum_alpha_now = self.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        cum_alpha_prev = self.alphas_cumprod[prev_timestep]
    else:
        cum_alpha_prev = self.final_alpha_cumprod

    noise_ratio = (1 - cum_alpha_prev) / (1 - cum_alpha_now)
    return noise_ratio * (1 - cum_alpha_now / cum_alpha_prev)
def set_timesteps(
    self,
    num_inference_steps: Optional[int] = None,
    device: Union[str, torch.device] = None,
    original_inference_steps: Optional[int] = None,
    timesteps: Optional[List[int]] = None,
    strength: float = 1.0,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`, *optional*):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        original_inference_steps (`int`, *optional*):
            The original number of inference steps, which will be used to generate a linearly-spaced timestep
            schedule (which is different from the standard `diffusers` implementation). We will then take
            `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
            our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute
            of the config; in that case the schedule is aligned with the discrete distillation steps, otherwise
            every training timestep is a candidate.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
            schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
        strength (`float`, *optional*, defaults to 1.0):
            Used to determine the number of timesteps used for inference when using img2img, inpaint, etc.

    Raises:
        `ValueError`: If not exactly one of `num_inference_steps` / `timesteps` is given, if `timesteps` is empty,
            not strictly descending or starts at or beyond `num_train_timesteps`, if `num_inference_steps` is not
            positive or is too large for the available schedule, or if no `original_inference_steps` value is
            available.
    """
    num_train_timesteps = self.config.num_train_timesteps

    # 0. Check inputs
    if num_inference_steps is None and timesteps is None:
        raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")

    if num_inference_steps is not None and timesteps is not None:
        raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")

    # An empty custom schedule would otherwise surface as an IndexError on `timesteps[0]` below.
    if timesteps is not None and len(timesteps) == 0:
        raise ValueError("`custom_timesteps` must contain at least one timestep.")

    # Zero would otherwise surface as a ZeroDivisionError when computing the skipping step.
    if num_inference_steps is not None and num_inference_steps < 1:
        raise ValueError(f"`num_inference_steps` must be a positive integer, but got {num_inference_steps}.")

    # 1. Calculate the TCD original training/distillation timestep schedule.
    original_steps = (
        original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
    )
    # Every path below compares against `original_steps`; fail early with a readable message instead of a
    # TypeError when neither the call nor the config provides it.
    if original_steps is None:
        raise ValueError(
            "`original_inference_steps` must be set either in the scheduler config or in the call to"
            " `set_timesteps`."
        )

    if original_inference_steps is None:
        # default option, timesteps align with discrete inference steps
        if original_steps > num_train_timesteps:
            raise ValueError(
                f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {num_train_timesteps} timesteps."
            )
        # TCD Timesteps Setting
        # The skipping step parameter k from the paper.
        k = num_train_timesteps // original_steps
        # TCD Training/Distillation Steps Schedule
        tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
    else:
        # customised option, sampled timesteps can be any arbitrary value
        tcd_origin_timesteps = np.asarray(list(range(0, int(num_train_timesteps * strength))))

    # 2. Calculate the TCD inference timestep schedule.
    if timesteps is not None:
        # 2.1 Handle custom timestep schedules.
        for i in range(1, len(timesteps)):
            if timesteps[i] >= timesteps[i - 1]:
                raise ValueError("`custom_timesteps` must be in descending order.")

        if timesteps[0] >= num_train_timesteps:
            raise ValueError(
                f"`timesteps` must start before `self.config.train_timesteps`:"
                f" {num_train_timesteps}."
            )

        # Collect every custom timestep that is off the training schedule, including the first one.
        train_timesteps = set(tcd_origin_timesteps)
        non_train_timesteps = [t for t in timesteps if t not in train_timesteps]

        # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
        if strength == 1.0 and timesteps[0] != num_train_timesteps - 1:
            logger.warning(
                f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
                f" `self.config.num_train_timesteps - 1`: {num_train_timesteps - 1}. You may get"
                f" unexpected results when using this timestep schedule."
            )

        # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
        if non_train_timesteps:
            logger.warning(
                f"The custom timestep schedule contains the following timesteps which are not on the original"
                f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
                f" when using this timestep schedule."
            )

        # Raise warning if custom timestep schedule is longer than original_steps
        if len(timesteps) > original_steps:
            logger.warning(
                f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
                f" the length of the timestep schedule used for training: {original_steps}. You may get some"
                f" unexpected results when using this timestep schedule."
            )

        timesteps = np.array(timesteps, dtype=np.int64)
        self.num_inference_steps = len(timesteps)
        self.custom_timesteps = True

        # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
        init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps)
        t_start = max(self.num_inference_steps - init_timestep, 0)
        timesteps = timesteps[t_start * self.order :]
        # TODO: also reset self.num_inference_steps?
    else:
        # 2.2 Create the "standard" TCD inference timestep schedule.
        if num_inference_steps > num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {num_train_timesteps} timesteps."
            )

        skipping_step = len(tcd_origin_timesteps) // num_inference_steps
        if skipping_step < 1:
            raise ValueError(
                f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than"
                f" `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps`"
                f" to a value smaller than {int(original_steps * strength)} or increase `strength` to a value"
                f" higher than {float(num_inference_steps / original_steps)}."
            )

        self.num_inference_steps = num_inference_steps
        # A previous call with custom timesteps must not leave the flag set for a standard schedule.
        self.custom_timesteps = False

        if num_inference_steps > original_steps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
                f" {original_steps} because the final timestep schedule will be a subset of the"
                f" `original_inference_steps`-sized initial timestep schedule."
            )

        # TCD Inference Steps Schedule
        tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy()
        # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
        inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
        inference_indices = np.floor(inference_indices).astype(np.int64)
        timesteps = tcd_origin_timesteps[inference_indices]

    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)

    # A new schedule invalidates any step bookkeeping from a previous run.
    self._step_index = None
    self._begin_index = None
Returns: [`~schedulers.scheduling_utils.TCDSchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if self.step_index is None: self._init_step_index(timestep) assert 0 <= eta <= 1.0, "gamma must be less than or equal to 1.0" # 1. get previous step value prev_step_index = self.step_index + 1 if prev_step_index < len(self.timesteps): prev_timestep = self.timesteps[prev_step_index] else: prev_timestep = torch.tensor(0) timestep_s = torch.floor((1 - eta) * prev_timestep).to(dtype=torch.long) # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] beta_prod_t = 1 - alpha_prod_t alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod alpha_prod_s = self.alphas_cumprod[timestep_s] beta_prod_s = 1 - alpha_prod_s # 3. 
Compute the predicted noised sample x_s based on the model parameterization if self.config.prediction_type == "epsilon": # noise-prediction pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() pred_epsilon = model_output pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon elif self.config.prediction_type == "sample": # x-prediction pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon elif self.config.prediction_type == "v_prediction": # v-prediction pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" " `v_prediction` for `TCDScheduler`." ) # 4. Sample and inject noise z ~ N(0, I) for MultiStep Inference # Noise is not used on the final timestep of the timestep schedule. # This also means that noise is not used for one-step sampling. # Eta (referred to as "gamma" in the paper) was introduced to control the stochasticity in every step. # When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling. 
# Adapted from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Mix `original_samples` with `noise` at the given `timesteps`.

    Computes `sqrt(a_t) * original_samples + sqrt(1 - a_t) * noise`, where `a_t` is `self.alphas_cumprod` indexed
    by `timesteps`. As a side effect `self.alphas_cumprod` is moved to the device of `original_samples`, which
    avoids redundant CPU to GPU data movement on subsequent calls.

    Args:
        original_samples (`torch.Tensor`): The samples to add noise to.
        noise (`torch.Tensor`): The noise to mix in.
        timesteps (`torch.IntTensor`): Indices into `self.alphas_cumprod`.

    Returns:
        `torch.Tensor`: The noisy samples.
    """
    target_device = original_samples.device
    self.alphas_cumprod = self.alphas_cumprod.to(device=target_device)
    cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
    picked = cumprod[timesteps.to(target_device)]

    # Flattened coefficients get trailing singleton axes until they have as many dims as the samples.
    trailing = [1] * max(original_samples.ndim - 1, 0)

    signal_scale = (picked**0.5).flatten()
    signal_scale = signal_scale.reshape(signal_scale.shape[0], *trailing)

    noise_scale = ((1 - picked) ** 0.5).flatten()
    noise_scale = noise_scale.reshape(noise_scale.shape[0], *trailing)

    return signal_scale * original_samples + noise_scale * noise
# Adapted from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
    """
    Return the timestep that follows `timestep` in the denoising order.

    With a custom schedule the next entry of `self.timesteps` is returned, or `torch.tensor(-1)` when `timestep` is
    the last entry. Otherwise a fixed stride of `num_train_timesteps // num_inference_steps` is subtracted, using
    `num_train_timesteps` as the step count when no inference step count has been set.

    Args:
        timestep: The current timestep.

    Returns:
        The previous timestep.
    """
    if not self.custom_timesteps:
        step_count = self.num_inference_steps or self.config.num_train_timesteps
        return timestep - self.config.num_train_timesteps // step_count

    # First position at which the schedule equals the current timestep.
    position = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
    at_end = position == self.timesteps.shape[0] - 1
    return torch.tensor(-1) if at_end else self.timesteps[position + 1]
# Adapted from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    For step `i` the beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of betas used by the scheduler to step the model outputs

    Raises:
        `ValueError`: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":
        alpha_bar_fn = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2  # noqa: E731
    elif alpha_transform_type == "exp":
        alpha_bar_fn = lambda t: math.exp(t * -12.0)  # noqa: E731
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    capped = [
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta) for step in range(total)
    ]
    return torch.tensor(capped, dtype=torch.float32)
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    """
    Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

    Note that this scheduler uses a slightly different step ratio than the other diffusers schedulers. The
    different step ratio is to mimic the original karlo implementation and does not affect the quality or accuracy
    of the results.

    The schedule runs from `num_train_timesteps - 1` down to `0`, with the end points included and the interior
    points rounded to the nearest integer. A single-step schedule consists of `num_train_timesteps - 1` only.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            the device the timesteps tensor is moved to.
    """
    self.num_inference_steps = num_inference_steps
    last_train_step = self.config.num_train_timesteps - 1

    if num_inference_steps == 1:
        # The spacing below divides by `num_inference_steps - 1`; with one step there is no spacing to
        # compute, so use the start of the descending schedule directly.
        timesteps = np.array([last_train_step], dtype=np.int64)
    else:
        step_ratio = last_train_step / (num_inference_steps - 1)
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)

    self.timesteps = torch.from_numpy(timesteps).to(device)
Args: model_output (`torch.Tensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.Tensor`): current instance of sample being created by diffusion process. prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at. Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used. generator: random number generator. return_dict (`bool`): option for returning tuple rather than UnCLIPSchedulerOutput class Returns: [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ t = timestep if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type == "learned_range": model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: predicted_variance = None # 1. compute alphas, betas if prev_timestep is None: prev_timestep = t - 1 alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev if prev_timestep == t - 1: beta = self.betas[t] alpha = self.alphas[t] else: beta = 1 - alpha_prod_t / alpha_prod_t_prev alpha = 1 - beta # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) elif self.config.prediction_type == "sample": pred_original_sample = model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `sample`" " for the UnCLIPScheduler." ) # 3. 
Clip "predicted x_0" if self.config.clip_sample: pred_original_sample = torch.clamp( pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range ) # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * beta) / beta_prod_t current_sample_coeff = alpha ** (0.5) * beta_prod_t_prev / beta_prod_t # 5. Compute predicted previous sample µ_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample # 6. Add noise variance = 0 if t > 0: variance_noise = randn_tensor( model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device ) variance = self._get_variance( t, predicted_variance=predicted_variance, prev_timestep=prev_timestep, ) if self.variance_type == "fixed_small_log": variance = variance elif self.variance_type == "learned_range": variance = (0.5 * variance).exp() else: raise ValueError( f"variance_type given as {self.variance_type} must be one of `fixed_small_log` or `learned_range`" " for the UnCLIPScheduler." 
# Adapted from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Diffuse clean samples forward to the given timesteps.

    Computes `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise`, where `alpha_bar_t` is read
    from `self.alphas_cumprod` at each entry of `timesteps`.

    Args:
        original_samples (`torch.Tensor`): the clean samples; their device and dtype drive the computation.
        noise (`torch.Tensor`): the noise to blend in.
        timesteps (`torch.IntTensor`): indices into `self.alphas_cumprod`.

    Returns:
        `torch.Tensor`: the noised samples.
    """
    target_device = original_samples.device

    # The device move is stored back on the scheduler so later calls do not repeat the transfer;
    # the dtype cast stays local to this call.
    self.alphas_cumprod = self.alphas_cumprod.to(device=target_device)
    alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
    alpha_bar = alphas_cumprod[timesteps.to(target_device)]

    def _as_broadcastable(coeff):
        # Flatten, then append trailing singleton axes until the rank matches `original_samples`.
        coeff = coeff.flatten()
        missing = len(original_samples.shape) - len(coeff.shape)
        if missing > 0:
            coeff = coeff[(slice(None),) + (None,) * missing]
        return coeff

    signal_scale = _as_broadcastable(alpha_bar**0.5)
    noise_scale = _as_broadcastable((1 - alpha_bar) ** 0.5)

    return signal_scale * original_samples + noise_scale * noise
# Adapted from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous `alpha_bar(t)` curve, t in [0, 1], into a beta schedule.

    `alpha_bar(t)` is the cumulative product of (1 - beta) up to time t. Each beta is
    `1 - alpha_bar(t_next) / alpha_bar(t_current)` on a uniform grid, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): upper bound applied to every beta; keep it below 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): shape of `alpha_bar`, either `cosine`
            or `exp`.

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps`.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar_fn = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar_fn = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta) for step in range(total)
    ]
    return torch.tensor(betas, dtype=torch.float32)


# Adapted from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescale betas so that the final timestep has zero signal-to-noise ratio.

    Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Work on sqrt(alpha_bar), the cumulative signal scale.
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()
    first = sqrt_alpha_bar[0]
    last = sqrt_alpha_bar[-1]

    # Shift so the last entry becomes zero, then rescale so the first entry keeps its old value.
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    # Undo the sqrt, then undo the cumulative product to recover per-step alphas.
    alpha_bar = sqrt_alpha_bar**2
    alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - alphas
Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. beta_start (`float`, defaults to 0.0001): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. beta_schedule (`str`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. solver_order (`int`, default `2`): The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. dynamic_thresholding_ratio (`float`, defaults to 0.995): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. predict_x0 (`bool`, defaults to `True`): Whether to use the updating algorithm on the predicted x0. solver_type (`str`, default `bh2`): Solver type for UniPC. 
It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` otherwise. lower_order_final (`bool`, default `True`): Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. disable_corrector (`list`, default `[]`): Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is usually disabled during the first few steps. solver_p (`SchedulerMixin`, default `None`): Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. final_sigmas_type (`str`, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    solver_order: int = 2,
    prediction_type: str = "epsilon",
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    predict_x0: bool = True,
    solver_type: str = "bh2",
    lower_order_final: bool = True,
    disable_corrector: List[int] = [],
    solver_p: SchedulerMixin = None,
    use_karras_sigmas: Optional[bool] = False,
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
    final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
    rescale_betas_zero_snr: bool = False,
):
    """
    Build the training noise schedule and reset the multistep solver state.

    The betas come from `trained_betas` when given, otherwise from `beta_schedule` (`linear`, `scaled_linear`
    or `squaredcos_cap_v2`). From them the constructor derives `alphas`, `alphas_cumprod`, `alpha_t`, `sigma_t`,
    `lambda_t` and `sigmas`, and initializes `timesteps` to the full descending training range.

    Raises:
        NotImplementedError: if `beta_schedule` is not one of the three supported names, or if `solver_type` is
            neither `bh1`/`bh2` nor one of the aliases `midpoint`/`heun`/`logrho` (which are remapped to `bh2`).
    """
    if trained_betas is not None:
        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        self.betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    if rescale_betas_zero_snr:
        self.betas = rescale_zero_terminal_snr(self.betas)

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    if rescale_betas_zero_snr:
        # Close to 0 without being 0 so first sigma is not inf
        # FP16 smallest positive subnormal works well here
        self.alphas_cumprod[-1] = 2**-24

    # Currently we only support VP-type noise schedule
    self.alpha_t = torch.sqrt(self.alphas_cumprod)
    self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
    self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
    self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    if solver_type not in ["bh1", "bh2"]:
        if solver_type in ["midpoint", "heun", "logrho"]:
            # Solver names borrowed from other schedulers fall back to bh2 in the stored config.
            self.register_to_config(solver_type="bh2")
        else:
            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

    self.predict_x0 = predict_x0
    # setable values
    self.num_inference_steps = None
    timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
    self.timesteps = torch.from_numpy(timesteps)
    self.model_outputs = [None] * solver_order
    self.timestep_list = [None] * solver_order
    self.lower_order_nums = 0
    # Keep a private copy. The signature default `[]` is one list object shared by every call, so storing it by
    # reference would let an in-place edit on one scheduler (e.g. `scheduler.disable_corrector.append(0)`) leak
    # into every scheduler created afterwards. The default in the signature is left as-is so the value recorded
    # by `register_to_config` does not change. `None` is treated as "no steps disabled".
    self.disable_corrector = list(disable_corrector) if disable_corrector is not None else []
    self.solver_p = solver_p
    self.last_sample = None
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Besides `self.timesteps` and `self.sigmas`, this resets the multistep history (`model_outputs`,
    `lower_order_nums`, `last_sample`) and the step/begin index counters, and forwards the step count to
    `self.solver_p` when one is configured.

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

    Raises:
        ValueError: if `timestep_spacing` or `final_sigmas_type` in the config holds an unsupported value.
    """
    cfg = self.config
    spacing = cfg.timestep_spacing

    # The spacing names follow Table 2 of https://arxiv.org/abs/2305.08891
    if spacing == "linspace":
        grid = np.linspace(0, cfg.num_train_timesteps - 1, num_inference_steps + 1)
        # Descending order, with the trailing 0 endpoint dropped.
        timesteps = grid.round()[::-1][:-1].copy().astype(np.int64)
    elif spacing == "leading":
        stride = cfg.num_train_timesteps // (num_inference_steps + 1)
        grid = np.arange(0, num_inference_steps + 1) * stride
        timesteps = grid.round()[::-1][:-1].copy().astype(np.int64)
        timesteps += cfg.steps_offset
    elif spacing == "trailing":
        stride = cfg.num_train_timesteps / num_inference_steps
        timesteps = np.arange(cfg.num_train_timesteps, 0, -stride).round().copy().astype(np.int64)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    train_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    karras = cfg.use_karras_sigmas
    if karras:
        log_sigmas = np.log(train_sigmas)
        sigmas = self._convert_to_karras(
            in_sigmas=np.flip(train_sigmas).copy(), num_inference_steps=num_inference_steps
        )
        # With Karras sigmas the timesteps are re-derived from the sigmas, replacing the spacing-based ones.
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
    else:
        sigmas = np.interp(timesteps, np.arange(0, len(train_sigmas)), train_sigmas)

    final_type = cfg.final_sigmas_type
    if final_type == "zero":
        sigma_last = 0
    elif final_type == "sigma_min":
        if karras:
            sigma_last = sigmas[-1]
        else:
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
    else:
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
        )
    sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

    self.sigmas = torch.from_numpy(sigmas)
    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
    self.num_inference_steps = len(timesteps)

    # Drop any history left over from a previous sampling run.
    self.model_outputs = [None] * cfg.solver_order
    self.lower_order_nums = 0
    self.last_sample = None
    if self.solver_p:
        self.solver_p.set_timesteps(self.num_inference_steps, device=device)

    # add an index counter for schedulers that allow duplicated timesteps
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing pixels from saturation at each step. We find that dynamic thresholding results in significantly better photorealism as well as better image-text alignment, especially when using very large guidance weights." https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) abs_sample = sample.abs() # "a certain percentile absolute pixel value" s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) s = torch.clamp( s, min=1, max=self.config.sample_max_value ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" sample = sample.reshape(batch_size, channels, *remaining_dims) sample = sample.to(dtype) return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma log_sigma = np.log(np.maximum(sigma, 1e-10)) # get distribution dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low = log_sigmas[low_idx] high = log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) w = np.clip(w, 0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t # Copied from 
diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): alpha_t = 1 / ((sigma**2 + 1) ** 0.5) sigma_t = sigma * alpha_t return alpha_t, sigma_t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: """Constructs the noise schedule of Karras et al. (2022).""" # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers if hasattr(self.config, "sigma_min"): sigma_min = self.config.sigma_min else: sigma_min = None if hasattr(self.config, "sigma_max"): sigma_max = self.config.sigma_max else: sigma_max = None sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, num_inference_steps) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas def convert_model_output( self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, **kwargs, ) -> torch.Tensor: r""" Convert the model output to the corresponding type the UniPC algorithm needs. Args: model_output (`torch.Tensor`): The direct output from the learned diffusion model. timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. Returns: `torch.Tensor`: The converted model output. 
def multistep_uni_p_bh_update(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    order: int = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.

    The update moves `sample` from the noise level `self.sigmas[self.step_index]` to
    `self.sigmas[self.step_index + 1]`, using the stored history in `self.model_outputs` for the
    higher-order terms.

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model at the current timestep. Only used when
            `self.solver_p` is set; otherwise the update reads `self.model_outputs` instead.
        prev_timestep (`int`):
            Deprecated. Accepted as the first positional argument or as a keyword; passing it only triggers a
            deprecation notice.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. May also be passed as the second
            positional argument.
        order (`int`):
            The order of UniP at this timestep (corresponds to the *p* in UniPC-p). May also be passed as the
            third positional argument.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.

    Raises:
        ValueError: if `sample` or `order` is supplied neither by keyword nor positionally.
        NotImplementedError: if `self.config.solver_type` is neither `bh1` nor `bh2`.
    """
    # Legacy positional calling convention: (prev_timestep, sample, order).
    prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 1:
            sample = args[1]
        else:
            raise ValueError(" missing `sample` as a required keyward argument")
    if order is None:
        if len(args) > 2:
            order = args[2]
        else:
            raise ValueError(" missing `order` as a required keyward argument")
    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )
    model_output_list = self.model_outputs

    # Most recent entries of the history: timestep `s0` and its stored model output `m0`.
    s0 = self.timestep_list[-1]
    m0 = model_output_list[-1]
    x = sample

    # A user-supplied predictor replaces the built-in UniP update entirely.
    if self.solver_p:
        x_t = self.solver_p.step(model_output, s0, x).prev_sample
        return x_t

    # Target noise level (`t`) is the next entry of the sigma table; `s0` is the current one.
    sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

    # lambda = log(alpha) - log(sigma); `h` is the step size measured in lambda.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

    h = lambda_t - lambda_s0
    device = sample.device

    # For each older history entry: its lambda offset relative to `h` (rk) and the scaled difference of its
    # stored output against `m0`. The loop is empty when order == 1.
    rks = []
    D1s = []
    for i in range(1, order):
        si = self.step_index - i
        mi = model_output_list[-(i + 1)]
        alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
        lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
        rk = (lambda_si - lambda_s0) / h
        rks.append(rk)
        D1s.append((mi - m0) / rk)

    rks.append(1.0)
    rks = torch.tensor(rks, device=device)

    R = []
    b = []

    # The sign of the step flips depending on whether the stored outputs are x0 predictions.
    hh = -h if self.predict_x0 else h
    h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
    h_phi_k = h_phi_1 / hh - 1

    factorial_i = 1

    if self.config.solver_type == "bh1":
        B_h = hh
    elif self.config.solver_type == "bh2":
        B_h = torch.expm1(hh)
    else:
        raise NotImplementedError()

    # Row i of R holds rks**(i - 1); b holds the matching right-hand-side entries.
    for i in range(1, order + 1):
        R.append(torch.pow(rks, i - 1))
        b.append(h_phi_k * factorial_i / B_h)
        factorial_i *= i + 1
        h_phi_k = h_phi_k / hh - 1 / factorial_i

    R = torch.stack(R)
    b = torch.tensor(b, device=device)

    if len(D1s) > 0:
        D1s = torch.stack(D1s, dim=1)  # (B, K)
        # for order 2, we use a simplified version
        if order == 2:
            rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            # Solve with the last row and column dropped (the ones belonging to the appended 1.0 in `rks`).
            rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
    else:
        # No history differences available (order == 1): the higher-order residual below is zero.
        D1s = None

    if self.predict_x0:
        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
        else:
            pred_res = 0
        x_t = x_t_ - alpha_t * B_h * pred_res
    else:
        x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
        if D1s is not None:
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
        else:
            pred_res = 0
        x_t = x_t_ - sigma_t * B_h * pred_res

    # Return in the dtype of the incoming sample.
    x_t = x_t.to(x.dtype)
    return x_t
""" this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) if last_sample is None: if len(args) > 1: last_sample = args[1] else: raise ValueError(" missing`last_sample` as a required keyward argument") if this_sample is None: if len(args) > 2: this_sample = args[2] else: raise ValueError(" missing`this_sample` as a required keyward argument") if order is None: if len(args) > 3: order = args[3] else: raise ValueError(" missing`order` as a required keyward argument") if this_timestep is not None: deprecate( "this_timestep", "1.0.0", "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) model_output_list = self.model_outputs m0 = model_output_list[-1] x = last_sample x_t = this_sample model_t = this_model_output sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) h = lambda_t - lambda_s0 device = this_sample.device rks = [] D1s = [] for i in range(1, order): si = self.step_index - (i + 1) mi = model_output_list[-(i + 1)] alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) lambda_si = torch.log(alpha_si) - torch.log(sigma_si) rk = (lambda_si - lambda_s0) / h rks.append(rk) D1s.append((mi - m0) / rk) rks.append(1.0) rks = torch.tensor(rks, device=device) R = [] b = [] hh = -h if self.predict_x0 else h h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 h_phi_k = h_phi_1 / hh - 1 factorial_i = 1 if self.config.solver_type == "bh1": B_h = hh elif self.config.solver_type == "bh2": B_h = torch.expm1(hh) else: raise NotImplementedError() for i in range(1, order + 1): R.append(torch.pow(rks, i - 1)) b.append(h_phi_k * factorial_i / B_h) factorial_i *= i + 1 h_phi_k = h_phi_k / hh - 1 / factorial_i R = 
# Adapted from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Find the position of `timestep` within a timestep schedule.

    Args:
        timestep: the timestep value to look up.
        schedule_timesteps: schedule to search; `self.timesteps` when omitted.

    Returns:
        `int`: the index of the match. When the value occurs more than once, the second occurrence is returned;
        when it does not occur at all, the last index of `self.timesteps` is returned.
    """
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (schedule == timestep).nonzero()
    match_count = len(matches)

    if match_count == 0:
        return len(self.timesteps) - 1

    # The sigma index that is taken for the **very** first `step` is always the second index (or the only one if
    # there is a single match). This avoids accidentally skipping a sigma when starting in the middle of the
    # denoising schedule (e.g. for image-to-image).
    chosen = 1 if match_count > 1 else 0
    return matches[chosen].item()
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance the sample by one multistep UniPC update (optional corrector, then predictor).

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        ValueError: if `set_timesteps` has not been called (`self.num_inference_steps` is `None`).
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    # The corrector needs a previous step to refine, must not be disabled for that step, and needs the sample
    # that was fed to the previous predictor.
    corrector_enabled = (
        self.step_index > 0
        and (self.step_index - 1) not in self.disable_corrector
        and self.last_sample is not None
    )

    converted_output = self.convert_model_output(model_output, sample=sample)
    if corrector_enabled:
        sample = self.multistep_uni_c_bh_update(
            this_model_output=converted_output,
            last_sample=self.last_sample,
            this_sample=sample,
            order=self.this_order,
        )

    # Slide the history window one slot to the left (in place) and append the newest entry.
    max_order = self.config.solver_order
    outputs, stamps = self.model_outputs, self.timestep_list
    for slot in range(max_order - 1):
        outputs[slot] = outputs[slot + 1]
        stamps[slot] = stamps[slot + 1]
    outputs[-1] = converted_output
    stamps[-1] = timestep

    order_cap = max_order
    if self.config.lower_order_final:
        # Never use an order higher than the number of steps that remain.
        order_cap = min(max_order, len(self.timesteps) - self.step_index)

    # warmup for multistep: the order is also limited by how much history has been collected
    self.this_order = min(order_cap, self.lower_order_nums + 1)
    assert self.this_order > 0

    self.last_sample = sample
    prev_sample = self.multistep_uni_p_bh_update(
        model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
        sample=sample,
        order=self.this_order,
    )

    if self.lower_order_nums < max_order:
        self.lower_order_nums += 1

    # upon completion increase step index by one
    self._step_index += 1

    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)
""" return sample # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise def add_noise( self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] elif self.step_index is not None: # add_noise is called after first denoising step (for inpainting) step_indices = [self.step_index] * timesteps.shape[0] else: # add noise is called before first denoising step to create initial latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: src/diffusers/schedulers/scheduling_utils.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import importlib import os from dataclasses import dataclass from enum import Enum from typing import Optional, Union import torch from huggingface_hub.utils import validate_hf_hub_args from ..utils import BaseOutput, PushToHubMixin SCHEDULER_CONFIG_NAME = "scheduler_config.json" # NOTE: We make this type an enum because it simplifies usage in docs and prevents # circular imports when used for `_compatibles` within the schedulers module. # When it's used as a type in pipelines, it really is a Union because the actual # scheduler instance is passed in. class KarrasDiffusionSchedulers(Enum): DDIMScheduler = 1 DDPMScheduler = 2 PNDMScheduler = 3 LMSDiscreteScheduler = 4 EulerDiscreteScheduler = 5 HeunDiscreteScheduler = 6 EulerAncestralDiscreteScheduler = 7 DPMSolverMultistepScheduler = 8 DPMSolverSinglestepScheduler = 9 KDPM2DiscreteScheduler = 10 KDPM2AncestralDiscreteScheduler = 11 DEISMultistepScheduler = 12 UniPCMultistepScheduler = 13 DPMSolverSDEScheduler = 14 EDMEulerScheduler = 15 AysSchedules = { "StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24], "StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0], "StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13], "StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0], "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0], } @dataclass class SchedulerOutput(BaseOutput): """ Base class for the output of a scheduler's 
class SchedulerMixin(PushToHubMixin):
    """
    Common base for every scheduler.

    [`SchedulerMixin`] bundles the functionality all schedulers share, chiefly loading and saving.
    [`ConfigMixin`] is responsible for recording the configuration attributes (such as `num_train_timesteps`)
    given to a scheduler's `__init__`; they are then reachable as `scheduler.config.num_train_timesteps`.

    Class attributes:
        - **_compatibles** (`List[str]`) -- Names of scheduler classes that are compatible with the parent
          scheduler class. Use [`~ConfigMixin.from_config`] to load a different compatible scheduler class
          (should be overridden by parent class).
    """

    config_name = SCHEDULER_CONFIG_NAME
    _compatibles = []
    has_compatibles = True

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ):
        r"""
        Build a scheduler from a JSON configuration file stored in a local directory or in a Hub repository.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                One of:

                    - The *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub, given as a string.
                    - The path of a *directory* (for example `./my_model_directory`) holding a scheduler
                      configuration written by [`~SchedulerMixin.save_pretrained`].
            subfolder (`str`, *optional*):
                Location of the configuration inside a larger model repository, on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that the Python class does not consume should be returned as well.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Directory in which a downloaded pretrained model configuration is cached when the standard
                cache is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether to (re-)download the model weights and configuration files even if cached versions
                exist.
            proxies (`Dict[str, str]`, *optional*):
                Proxy servers keyed by protocol or endpoint, for example `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether to also return a dictionary containing missing keys, unexpected keys and error
                messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether to restrict loading to local model weights and configuration files. When `True`,
                nothing is downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                Token used as HTTP bearer authorization for remote files. With `True`, the token generated
                by `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The model version to use: a branch name, a tag name, a commit id, or any other identifier
                Git allows.

        <Tip>

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
        with `huggingface-cli login`. You can also activate the special
        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in
        a firewalled environment.

        </Tip>

        """
        # `load_config` is asked for the commit hash as well, but only the config and the leftover kwargs
        # are forwarded.
        config, leftover_kwargs, _commit_hash = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            **kwargs,
        )
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **leftover_kwargs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Write this scheduler's configuration to a directory, from which
        [`~SchedulerMixin.from_pretrained`] can load it again.

        Args:
            save_directory (`str` or `os.PathLike`):
                Target directory for the configuration JSON file (created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether to push the model to the Hugging Face Hub once it is saved. The destination
                repository can be chosen with `repo_id` (defaults to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Extra keyword arguments handed to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

    @property
    def compatibles(self):
        """
        The schedulers that are compatible with this scheduler.

        Returns:
            `List[SchedulerMixin]`: List of compatible schedulers
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        # De-duplicate this class's own name together with the declared compatible names.
        candidate_names = set([cls.__name__] + cls._compatibles)
        # The top-level package is the first dotted component of this module's name.
        root_package = importlib.import_module(__name__.split(".")[0])

        resolved = []
        for class_name in candidate_names:
            # Names the top-level package does not expose are skipped silently.
            if hasattr(root_package, class_name):
                resolved.append(getattr(root_package, class_name))
        return resolved
class FlaxSchedulerMixin(PushToHubMixin):
    """
    Mixin containing common functions for the schedulers.

    Class attributes:
        - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
          `from_config` can be used from a class different than the one used to save the config (should be
          overridden by parent class).
    """

    config_name = SCHEDULER_CONFIG_NAME
    # NOTE(review): presumably the attributes named here are left out of the stored config; confirm against
    # the config mixin.
    ignore_for_config = ["dtype"]
    _compatibles = []
    has_compatibles = True

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ):
        r"""
        Instantiate a Scheduler class from a pre-defined JSON-file.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
                      organization name, like `google/ddpm-celebahq-256`.
                    - A path to a *directory* containing model weights saved using
                      [`~FlaxSchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
            subfolder (`str`, *optional*):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files,
                overriding the cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error
                messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
                generated when running `diffusers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we
                use a git-based system for storing models and other artifacts on huggingface.co, so `revision`
                can be any identifier allowed by git.

        Returns:
            `tuple`: `(scheduler, state)`, or `(scheduler, state, unused_kwargs)` when `return_unused_kwargs` is
            `True`. `state` is the result of `scheduler.create_state()` when the scheduler defines
            `create_state` and has a truthy `has_state`; otherwise it is `None`.

        <Tip>

         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
         models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode)
        to use this method in a firewalled environment.

        </Tip>

        """
        config, kwargs = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            **kwargs,
        )
        scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)

        # Bind `state` up front: without this, a scheduler lacking `create_state`/`has_state` made the return
        # statements below fail with UnboundLocalError.
        state = None
        if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
            state = scheduler.create_state()

        if return_unused_kwargs:
            return scheduler, state, unused_kwargs

        return scheduler, state

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded
        using the [`~FlaxSchedulerMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not
                exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in
                your namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

    @property
    def compatibles(self):
        """
        Returns all schedulers that are compatible with this scheduler

        Returns:
            `List[FlaxSchedulerMixin]`: List of compatible schedulers
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        # Own class name plus the declared compatible names, de-duplicated.
        compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
        # Resolve names on the top-level package (first dotted component of this module's name); names the
        # package does not expose are dropped.
        diffusers_library = importlib.import_module(__name__.split(".")[0])
        compatible_classes = [
            getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
        ]
        return compatible_classes
@flax.struct.dataclass
class CommonSchedulerState:
    """Beta/alpha arrays derived from a scheduler's configuration."""

    # 1 - betas
    alphas: jnp.ndarray
    # per-timestep betas built from the configured schedule
    betas: jnp.ndarray
    # cumulative product of `alphas` along axis 0
    alphas_cumprod: jnp.ndarray

    @classmethod
    def create(cls, scheduler):
        """
        Build the state from `scheduler.config`, using `scheduler.dtype` for the arrays.

        `trained_betas` takes precedence when it is set; otherwise `beta_schedule` selects one of `"linear"`,
        `"scaled_linear"` or `"squaredcos_cap_v2"`.

        Raises:
            NotImplementedError: If `beta_schedule` is none of the supported names.
        """
        cfg = scheduler.config

        if cfg.trained_betas is not None:
            betas = jnp.asarray(cfg.trained_betas, dtype=scheduler.dtype)
        else:
            schedule = cfg.beta_schedule
            if schedule == "linear":
                betas = jnp.linspace(
                    cfg.beta_start, cfg.beta_end, cfg.num_train_timesteps, dtype=scheduler.dtype
                )
            elif schedule == "scaled_linear":
                # this schedule is very specific to the latent diffusion model.
                root_betas = jnp.linspace(
                    cfg.beta_start**0.5, cfg.beta_end**0.5, cfg.num_train_timesteps, dtype=scheduler.dtype
                )
                betas = root_betas**2
            elif schedule == "squaredcos_cap_v2":
                # Glide cosine schedule
                betas = betas_for_alpha_bar(cfg.num_train_timesteps, dtype=scheduler.dtype)
            else:
                raise NotImplementedError(
                    f"beta_schedule {cfg.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}"
                )

        alphas = 1.0 - betas
        return cls(
            alphas=alphas,
            betas=betas,
            alphas_cumprod=jnp.cumprod(alphas, axis=0),
        )
def add_noise_common(
    state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
):
    """
    Blend `noise` into `original_samples` with the two coefficients `get_sqrt_alpha_prod` returns for
    `timesteps`.

    Returns:
        `sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise`.
    """
    signal_coeff, noise_coeff = get_sqrt_alpha_prod(state, original_samples, noise, timesteps)
    return signal_coeff * original_samples + noise_coeff * noise
def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009):
    """
    Cumulative and non-cumulative alpha schedules.

    See section 4.1.

    Returns:
        A pair `(at, att)`. `att` holds the cumulative values, linearly spaced from `alpha_cum_start` to
        `alpha_cum_end`, followed by a trailing `1` (length `num_diffusion_timesteps + 1`). `at` holds the
        per-step ratios between consecutive cumulative values, with the first ratio taken against `1`
        (length `num_diffusion_timesteps`).
    """
    progress = np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1)
    cumulative = progress * (alpha_cum_end - alpha_cum_start) + alpha_cum_start

    # Prepend 1 so the first step's ratio is taken against a cumulative value of 1.
    with_leading_one = np.concatenate(([1], cumulative))
    per_step = with_leading_one[1:] / with_leading_one[:-1]

    # The returned cumulative array carries a trailing 1 after the schedule values.
    cumulative_with_tail = np.concatenate((cumulative, [1]))
    return per_step, cumulative_with_tail
""" ctt = ( np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start) + gamma_cum_start ) ctt = np.concatenate(([0], ctt)) one_minus_ctt = 1 - ctt one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1] ct = 1 - one_minus_ct ctt = np.concatenate((ctt[1:], [0])) return ct, ctt class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): """ A scheduler for vector quantized diffusion. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. Args: num_vec_classes (`int`): The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked latent pixel. num_train_timesteps (`int`, defaults to 100): The number of diffusion steps to train the model. alpha_cum_start (`float`, defaults to 0.99999): The starting cumulative alpha value. alpha_cum_end (`float`, defaults to 0.00009): The ending cumulative alpha value. gamma_cum_start (`float`, defaults to 0.00009): The starting cumulative gamma value. gamma_cum_end (`float`, defaults to 0.99999): The ending cumulative gamma value. 
""" order = 1 @register_to_config def __init__( self, num_vec_classes: int, num_train_timesteps: int = 100, alpha_cum_start: float = 0.99999, alpha_cum_end: float = 0.000009, gamma_cum_start: float = 0.000009, gamma_cum_end: float = 0.99999, ): self.num_embed = num_vec_classes # By convention, the index for the mask class is the last class index self.mask_class = self.num_embed - 1 at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end) ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end) num_non_mask_classes = self.num_embed - 1 bt = (1 - at - ct) / num_non_mask_classes btt = (1 - att - ctt) / num_non_mask_classes at = torch.tensor(at.astype("float64")) bt = torch.tensor(bt.astype("float64")) ct = torch.tensor(ct.astype("float64")) log_at = torch.log(at) log_bt = torch.log(bt) log_ct = torch.log(ct) att = torch.tensor(att.astype("float64")) btt = torch.tensor(btt.astype("float64")) ctt = torch.tensor(ctt.astype("float64")) log_cumprod_at = torch.log(att) log_cumprod_bt = torch.log(btt) log_cumprod_ct = torch.log(ctt) self.log_at = log_at.float() self.log_bt = log_bt.float() self.log_ct = log_ct.float() self.log_cumprod_at = log_cumprod_at.float() self.log_cumprod_bt = log_cumprod_bt.float() self.log_cumprod_ct = log_cumprod_ct.float() # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps and diffusion process parameters (alpha, beta, gamma) should be moved to. 
""" self.num_inference_steps = num_inference_steps timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() self.timesteps = torch.from_numpy(timesteps).to(device) self.log_at = self.log_at.to(device) self.log_bt = self.log_bt.to(device) self.log_ct = self.log_ct.to(device) self.log_cumprod_at = self.log_cumprod_at.to(device) self.log_cumprod_bt = self.log_cumprod_bt.to(device) self.log_cumprod_ct = self.log_cumprod_ct.to(device) def step( self, model_output: torch.Tensor, timestep: torch.long, sample: torch.LongTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[VQDiffusionSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by the reverse transition distribution. See [`~VQDiffusionScheduler.q_posterior`] for more details about how the distribution is computer. Args: log_p_x_0: (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`): The log probabilities for the predicted classes of the initial latent pixels. Does not include a prediction for the masked class as the initial unnoised image cannot be masked. t (`torch.long`): The timestep that determines which transition matrices are used. x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`): The classes of each latent pixel at time `t`. generator (`torch.Generator`, or `None`): A random number generator for the noise applied to `p(x_{t-1} | x_t)` before it is sampled from. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput`] or `tuple`. Returns: [`~schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. 
""" if timestep == 0: log_p_x_t_min_1 = model_output else: log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep) log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator) x_t_min_1 = log_p_x_t_min_1.argmax(dim=1) if not return_dict: return (x_t_min_1,) return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1) def q_posterior(self, log_p_x_0, x_t, t): """ Calculates the log probabilities for the predicted classes of the image at timestep `t-1`: ``` p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) ) ``` Args: log_p_x_0 (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`): The log probabilities for the predicted classes of the initial latent pixels. Does not include a prediction for the masked class as the initial unnoised image cannot be masked. x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`): The classes of each latent pixel at time `t`. t (`torch.Long`): The timestep that determines which transition matrix is used. Returns: `torch.Tensor` of shape `(batch size, num classes, num latent pixels)`: The log probabilities for the predicted classes of the image at timestep `t-1`. """ log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed) log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class( t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True ) log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class( t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False ) # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) # . . . # . . . # . . . # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) q = log_p_x_0 - log_q_x_t_given_x_0 # sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... , # sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... 
+ p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True) # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n # . . . # . . . # . . . # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n q = q - q_log_sum_exp # (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} # . . . # . . . # . . . # (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} # c_cumulative_{t-1} ... c_cumulative_{t-1} q = self.apply_cumulative_transitions(q, t - 1) # ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n # . . . # . . . # . . . # ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n # c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp # For each column, there are two possible cases. # # Where: # - sum(p_n(x_0))) is summing over all classes for x_0 # - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's) # - C_j is the class transitioning to # # 1. x_t is masked i.e. x_t = c_k # # Simplifying the expression, the column vector is: # . # . # . 
def log_Q_t_transitioning_to_known_class(
    self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.Tensor, cumulative: bool
):
    """
    Selects, for every latent pixel in `x_t`, the matching row of the (cumulative or single-step) transition
    matrix and returns it in log space.

    Args:
        t (`torch.Long`):
            The timestep that determines which transition matrix is used.
        x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
            The classes of each latent pixel at time `t`.
        log_onehot_x_t (`torch.Tensor` of shape `(batch size, num classes, num latent pixels)`):
            The log one-hot vectors of `x_t`.
        cumulative (`bool`):
            If `False`, the single step transition matrix `t-1`->`t` is used. If `True`, the cumulative
            transition matrix `0`->`t` is used.

    Returns:
        `torch.Tensor`:
            Each _column_ of the result is a _row_ of log probabilities of the complete transition matrix.

            With `cumulative=True` the shape is `(batch size, num classes - 1, num latent pixels)`: the source
            class is `x_0`, which can never be the masked class, so there is no row for it.

            With `cumulative=False` the shape is `(batch size, num classes, num latent pixels)`: the source
            class is `x_{t-1}`, which may be masked, so the row for the masked class `C_k` is appended last.

            non-cumulative result (omitting logarithms):

            ```
            q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0)
                      .      .                     .
                      .               .            .
                      .                      .     .
            q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k)
            ```

            cumulative result (omitting logarithms):

            ```
            q_0_cumulative(x_t | x_0 = C_0)    ...  q_n_cumulative(x_t | x_0 = C_0)
                      .               .                          .
                      .                        .                 .
                      .                               .          .
            q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1})
            ```
    """
    if cumulative:
        log_a = self.log_cumprod_at[t]
        log_b = self.log_cumprod_bt[t]
        log_c = self.log_cumprod_ct[t]
        from_masked_row = None
    else:
        log_a = self.log_at[t]
        log_b = self.log_bt[t]
        log_c = self.log_ct[t]
        # The last one-hot row doubles as the transition probabilities *from* the masked class:
        # P(x_t != mask | x_{t-1} = mask) = 0 and P(x_t = mask | x_{t-1} = mask) = 1, which is exactly the
        # one-hot value stored there. Keep it aside so it can be re-appended untouched at the end.
        from_masked_row = log_onehot_x_t[:, -1, :].unsqueeze(1)

    # `index_to_log_onehot` also emits a row for the masked class; drop it, the rows computed here are for
    # non-masked source classes only.
    known_class_rows = log_onehot_x_t[:, :-1, :]

    # In probability space every entry becomes `1 * a + b` where the one-hot is set and `0 * a + b = b`
    # elsewhere (equation 7). Columns of masked pixels get arbitrary values here and are overwritten below.
    log_Q_t = (known_class_rows + log_a).logaddexp(log_b)

    # A pixel that is masked at time `t` was reached with probability `c` from every non-masked class.
    masked_columns = (x_t == self.mask_class).unsqueeze(1).expand(-1, self.num_embed - 1, -1)
    log_Q_t[masked_columns] = log_c

    if cumulative:
        return log_Q_t

    return torch.cat((log_Q_t, from_masked_row), dim=1)
def apply_cumulative_transitions(self, q, t):
    """
    Applies the cumulative (`0` -> `t`) transition to `q` in log space and appends the masked-class row.

    In probability space every entry of `q` becomes `q * a_cumulative_t + b_cumulative_t`, and one extra row
    filled with `c_cumulative_t` is concatenated along dim 1.

    Args:
        q (`torch.Tensor`): Log probabilities; dims 0 and 2 are read as batch size and number of latent pixels.
        t: Index into the `log_cumprod_at` / `log_cumprod_bt` / `log_cumprod_ct` tables.

    Returns:
        `torch.Tensor`: `q` after the transition, with one more row along dim 1.
    """
    batch_size = q.shape[0]
    pixel_count = q.shape[2]

    log_a = self.log_cumprod_at[t]
    log_b = self.log_cumprod_bt[t]
    masked_row = self.log_cumprod_ct[t].expand(batch_size, 1, pixel_count)

    transitioned = (q + log_a).logaddexp(log_b)
    return torch.cat((transitioned, masked_row), dim=1)


def set_seed(seed: int):
    """
    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.

    Args:
        seed (`int`): The seed to set.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # `torch.cuda.manual_seed_all` is safe to call even if cuda is not available.
    accelerator = torch.npu if is_torch_npu_available() else torch.cuda
    accelerator.manual_seed_all(seed)


def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    Args:
        noise_scheduler: Object exposing an `alphas_cumprod` tensor.
        timesteps (`torch.Tensor`): Indices into `alphas_cumprod`.

    Returns:
        `torch.Tensor` with the shape of `timesteps`: `(sqrt(alphas_cumprod) / sqrt(1 - alphas_cumprod)) ** 2`
        evaluated at each timestep.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod

    def _at_timesteps(per_step_values):
        # Expand the tensors. Adapted from
        # https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        picked = per_step_values.to(device=timesteps.device)[timesteps].float()
        while len(picked.shape) < len(timesteps.shape):
            picked = picked[..., None]
        return picked.expand(timesteps.shape)

    alpha = _at_timesteps(alphas_cumprod**0.5)
    sigma = _at_timesteps((1.0 - alphas_cumprod) ** 0.5)

    return (alpha / sigma) ** 2


def resolve_interpolation_mode(interpolation_type: str):
    """
    Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
    full list of supported enums is documented at
    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.

    Args:
        interpolation_type (`str`):
            A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`,
            `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes
            in torchvision.

    Returns:
        `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
        transform.

    Raises:
        ImportError: If `torchvision` is not installed.
        ValueError: If `interpolation_type` is not one of the supported names.
    """
    if not is_torchvision_available():
        raise ImportError(
            "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
        )

    # (user-facing name, enum member name). The member is looked up only for the matching entry, so enum members
    # that a given torchvision release does not define are never touched unless requested.
    supported_modes = (
        ("bilinear", "BILINEAR"),
        ("bicubic", "BICUBIC"),
        ("box", "BOX"),
        ("nearest", "NEAREST"),
        ("nearest_exact", "NEAREST_EXACT"),
        ("hamming", "HAMMING"),
        ("lanczos", "LANCZOS"),
    )
    for mode_name, member_name in supported_modes:
        if interpolation_type == mode_name:
            return getattr(transforms.InterpolationMode, member_name)

    raise ValueError(
        f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
        f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
    )
def compute_dream_and_update_latents(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210.
    DREAM helps align training with sampling to help training be more efficient and accurate at the cost of an extra
    forward step without gradients.

    Args:
        `unet`: The state unet to use to make a prediction.
        `noise_scheduler`: The noise scheduler used to add noise for the given timestep.
        `timesteps`: The timesteps for the noise_scheduler to use.
        `noise`: A tensor of noise in the shape of noisy_latents.
        `noisy_latents`: Previously noised latents from the training loop.
        `target`: The ground-truth tensor to predict after eps is removed.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: A float value that indicates detail preservation level.
          See reference.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.

    Raises:
        NotImplementedError: If the scheduler is configured for `v_prediction`.
        ValueError: If the scheduler's prediction type is neither `epsilon` nor `v_prediction`.
    """
    # Index per-sample and add three trailing singleton axes so the factors broadcast over the latents.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sigma = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
    dream_lambda = sigma**dream_detail_preservation

    # The extra forward pass is only used to build the correction term, so no gradients are tracked.
    with torch.no_grad():
        prediction = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    if prediction_type != "epsilon":
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    # epsilon prediction: the model output is the predicted noise.
    delta_noise = (noise - prediction).detach()
    delta_noise.mul_(dream_lambda)

    adjusted_latents = noisy_latents.add(sigma * delta_noise)
    adjusted_target = target.add(delta_noise)
    return adjusted_latents, adjusted_target
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
    r"""
    Collects the LoRA parameters of every sub-module of `unet` that carries a LoRA layer.

    A sub-module is considered when it has a `set_lora_layer` attribute and its `lora_layer` is not `None`.

    Returns:
        A state dict containing just the LoRA parameters, keyed as `{module name}.lora.{parameter name}`.
    """
    collected = {}

    for module_name, module in unet.named_modules():
        if not hasattr(module, "set_lora_layer"):
            continue

        lora_layer = module.lora_layer
        if lora_layer is None:
            continue

        # The matrix name can either be "down" or "up".
        for matrix_name, lora_param in lora_layer.state_dict().items():
            collected[f"{module_name}.lora.{matrix_name}"] = lora_param

    return collected
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
    """
    Casts the trainable parameters of one or several models to `dtype` in place.

    Parameters with `requires_grad=False` are left untouched, so frozen weights keep their (possibly lower)
    precision.

    Args:
        model: A single module, or a list/tuple of modules.
        dtype: The dtype the trainable parameters are cast to. Defaults to `torch.float32`.
    """
    # A tuple used to be wrapped as `[tuple]`, which then failed on `.parameters()`; accept it like a list.
    models = model if isinstance(model, (list, tuple)) else [model]

    for module in models:
        for param in module.parameters():
            # only upcast trainable parameters into fp32
            if param.requires_grad:
                param.data = param.to(dtype)


def _set_state_dict_into_text_encoder(
    lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
):
    """
    Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.

    Args:
        lora_state_dict: The state dictionary to be set.
        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
        text_encoder: Where the `lora_state_dict` is to be set.
    """
    # Strip only the *leading* prefix. `str.replace` would also delete any later occurrence of the same
    # substring inside the key and silently produce a wrong parameter name.
    prefix_length = len(prefix)
    text_encoder_state_dict = {
        key[prefix_length:]: value for key, value in lora_state_dict.items() if key.startswith(prefix)
    }
    text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")


def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: Optional[float] = None,
    logit_std: Optional[float] = None,
    mode_scale: Optional[float] = None,
):
    """Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

    Args:
        weighting_scheme: `"logit_normal"`, `"mode"`, or any other value for plain uniform sampling.
        batch_size: Number of samples to draw.
        logit_mean: Mean of the normal distribution; must be given for `"logit_normal"`.
        logit_std: Standard deviation of the normal distribution; must be given for `"logit_normal"`.
        mode_scale: Scale of the mode correction; must be given for `"mode"`.

    Returns:
        `torch.Tensor` of shape `(batch_size,)` created on the CPU.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device="cpu")
    return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

    Args:
        weighting_scheme: `"sigma_sqrt"` (weight `sigmas ** -2`), `"cosmap"`
            (weight `2 / (pi * (1 - 2 * sigmas + 2 * sigmas ** 2))`), or any other value for uniform weights.
        sigmas: Tensor of noise levels. Although it defaults to `None`, every branch reads it.

    Returns:
        The per-sample loss weights.
    """
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        weighting = torch.ones_like(sigmas)
    return weighting


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        foreach: bool = False,
        model_cls: Optional[Any] = None,
        model_config: Dict[str, Any] = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
            model_cls (Optional[Any]): Model class, required by `save_pretrained`.
            model_config (Dict[str, Any]): Model config, required by `save_pretrained`.
            **kwargs:
                Deprecated options: `max_value` (use `decay`), `min_value` (use `min_decay`) and `device` (call
                `to` instead). When `device` is not given, the shadow parameters are clones of `parameters` and
                therefore live on the same device.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """

        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
            use_ema_warmup = True

        if kwargs.get("max_value", None) is not None:
            deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
            deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
            decay = kwargs["max_value"]

        if kwargs.get("min_value", None) is not None:
            deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
            deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
            min_decay = kwargs["min_value"]

        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        if kwargs.get("device", None) is not None:
            deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
            deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
            self.to(device=kwargs["device"])

        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`
        self.foreach = foreach

        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
        """
        Builds an `EMAModel` from a model saved with `save_pretrained`.

        The model weights found at `path` become the shadow parameters; the config entries that the model class
        does not consume are loaded as the EMA state.
        """
        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        """
        Saves a model holding the averaged weights to `path`, with the EMA hyper-parameters stored in its config.

        Raises:
            ValueError: If `model_cls` or `model_config` was not provided at construction time.
        """
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        # The tensors are written as model weights, not as config entries.
        state_dict.pop("shadow_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.

        Returns `0.0` until `optimization_step` exceeds `update_after_step + 1`; afterwards the warmup (or default)
        schedule is evaluated and clamped to the range `[min_decay, decay]`.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # make sure decay is not smaller than min_decay
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        """
        Updates the shadow parameters with the current values of `parameters`.

        Trainable parameters are blended in as `shadow -= (1 - decay) * (shadow - param)`; parameters that do
        not require gradients are copied over verbatim.

        Args:
            parameters: Iterable of `torch.nn.Parameter`, in the same order as the parameters this object was
                built from.
        """
        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        zero3_enabled = is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled()
        if zero3_enabled:
            import deepspeed

        # `GatheredParameters(...)` already *is* a context manager instance. It must be entered directly;
        # calling it again (as `with context_manager():` did) raises a TypeError. The no-op fallback is
        # therefore an instance as well, so both cases are entered the same way.
        if self.foreach:
            if zero3_enabled:
                context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
            else:
                context_manager = contextlib.nullcontext()

            with context_manager:
                params_grad = [param for param in parameters if param.requires_grad]
                s_params_grad = [
                    s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
                ]

                if len(params_grad) < len(parameters):
                    torch._foreach_copy_(
                        [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
                        [param for param in parameters if not param.requires_grad],
                        non_blocking=True,
                    )

                torch._foreach_sub_(
                    s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
                )

        else:
            for s_param, param in zip(self.shadow_params, parameters):
                if zero3_enabled:
                    context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
                else:
                    context_manager = contextlib.nullcontext()

                with context_manager:
                    if param.requires_grad:
                        s_param.sub_(one_minus_decay * (s_param - param))
                    else:
                        s_param.copy_(param)

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. Must be provided (`None` is not supported) and must be
                in the same order as the parameters this object was built from.
        """
        parameters = list(parameters)
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters],
                [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
            )
        else:
            for s_param, param in zip(self.shadow_params, parameters):
                param.data.copy_(s_param.to(param.device).data)

    def pin_memory(self) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
        offloading EMA params to the host.
        """

        self.shadow_params = [p.pin_memory() for p in self.shadow_params]

    def to(self, device=None, dtype=None, non_blocking=False) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; only applied to floating point buffers.
            non_blocking: like `non_blocking` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype, non_blocking=non_blocking)
            if p.is_floating_point()
            else p.to(device=device, non_blocking=non_blocking)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored. Detached copies are kept on the CPU.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
        without affecting the original optimization process. Store the parameters before the `copy_to()` method.
        After validation (or model saving), use this to restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. Must be provided (`None` is not supported).

        Raises:
            RuntimeError: If `store()` has not been called since the last `restore()`.
        """
        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
            )
        else:
            for c_param, param in zip(self.temp_stored_params, parameters):
                param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to load
        the ema state dict.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`. Missing keys keep their current value.

        Raises:
            ValueError: If any entry has an unexpected type or `decay` is outside `[0, 1]`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        # Integers are accepted like for `inv_gamma` and `power`; otherwise an instance built with
        # `min_decay=0` could not load the state it produced itself.
        if not isinstance(self.min_decay, (float, int)):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from packaging import version from .. import __version__ from .constants import ( CONFIG_NAME, DEPRECATED_REVISION_ARGS, DIFFUSERS_DYNAMIC_MODULE_NAME, FLAX_WEIGHTS_NAME, HF_MODULES_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, MIN_PEFT_VERSION, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_FILE_EXTENSION, SAFETENSORS_WEIGHTS_NAME, USE_PEFT_BACKEND, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, ) from .deprecation_utils import deprecate from .doc_utils import replace_example_docstring from .dynamic_modules_utils import get_class_from_dynamic_module from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video from .hub_utils import ( PushToHubMixin, _add_variant, _get_checkpoint_shard_files, _get_model_file, extract_commit_hash, http_user_agent, ) from .import_utils import ( BACKENDS_MAPPING, DIFFUSERS_SLOW_IMPORT, ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_VALUES, USE_JAX, USE_TF, USE_TORCH, DummyObject, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, is_accelerate_available, is_accelerate_version, is_bitsandbytes_available, is_bs4_available, is_flax_available, is_ftfy_available, is_google_colab, is_inflect_available, is_invisible_watermark_available, is_k_diffusion_available, is_k_diffusion_version, is_librosa_available, is_matplotlib_available, is_note_seq_available, is_onnx_available, is_peft_available, is_peft_version, is_safetensors_available, is_scipy_available, is_sentencepiece_available, is_tensorboard_available, is_timm_available, is_torch_available, is_torch_npu_available, 
def check_min_version(min_version):
    """
    Ensures that the installed diffusers version is not older than `min_version`.

    Args:
        min_version: Version string that the installed `__version__` is compared against.

    Raises:
        ImportError: If the installed version is lower than `min_version`. When `min_version` contains `dev`,
            the message asks for a source install instead of naming the minimum version.
    """
    if not version.parse(__version__) < version.parse(min_version):
        return

    if "dev" in min_version:
        requirement = (
            "This example requires a source install from HuggingFace diffusers (see "
            "`https://huggingface.co/docs/diffusers/installation#install-from-source`),"
        )
    else:
        requirement = f"This example requires a minimum version of {min_version},"

    raise ImportError(requirement + f" but the version found is {__version__}.\n")
def apply_forward_hook(method):
    """
    Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful
    for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the
    appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`].

    This decorator looks inside the internal `_hf_hook` property to find a registered offload hook.

    The method is returned undecorated when `accelerate` is not installed or is older than 0.17.0.

    :param method: The method to decorate. This method should be a method of a PyTorch module.
    """
    # Imported locally so this function does not depend on the module's import block.
    import functools

    if not is_accelerate_available():
        return method
    # `base_version` drops pre-release/dev suffixes so that e.g. a `0.17.0.dev0` install passes the check.
    accelerate_version = version.parse(accelerate.__version__).base_version
    if version.parse(accelerate_version) < version.parse("0.17.0"):
        return method

    # `wraps` keeps the decorated method's name, docstring and signature visible to introspection and
    # documentation tooling instead of exposing a generic `wrapper`.
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):
            self._hf_hook.pre_forward(self)
        return method(self, *args, **kwargs)

    return wrapper
import importlib
import os

from huggingface_hub.constants import HF_HOME
from packaging import version

from ..dependency_versions_check import dep_version_check
from .import_utils import ENV_VARS_TRUE_VALUES, is_peft_available, is_transformers_available


# Minimum library versions compared against the installed `peft` / `transformers` versions below.
MIN_PEFT_VERSION = "0.6.0"
MIN_TRANSFORMERS_VERSION = "4.34.0"
# Read from the `_CHECK_PEFT` environment variable (defaults to "1"); true when the value is one of
# `ENV_VARS_TRUE_VALUES`.
_CHECK_PEFT = os.environ.get("_CHECK_PEFT", "1") in ENV_VARS_TRUE_VALUES

# File-name and extension string constants.
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.bin.index.json"
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
SAFETENSORS_FILE_EXTENSION = "safetensors"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
# Overridable through the `HF_ENDPOINT` environment variable.
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
# Overridable through the `HF_MODULES_CACHE` environment variable; defaults to `<HF_HOME>/modules`.
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]

# Below should be `True` if the current versions of `peft` and `transformers` are compatible with the
# PEFT backend. Will automatically fall back to the PEFT backend if the correct versions of the libraries are
# available.
# For PEFT it has to be greater than or equal to 0.6.0 and for transformers it has to be greater than or equal to
# 4.34.0.
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
    """
    Emit `FutureWarning`s for deprecated arguments/attributes and hand back their values.

    :param args: Either a single `(attribute, version_name, message)` triple passed as three positional
        arguments, or one or more such triples each passed as a tuple.
    :param take_from: Where to look for the deprecated value. A `dict` has the matching key popped; any other
        object has the matching attribute read; `None` means only a warning is issued.
    :param standard_warn: When true, a standard "is deprecated and will be removed" sentence is put in front of
        `message`.
    :param stacklevel: Forwarded to `warnings.warn`.
    :return: `None` when no value was collected, the value itself when exactly one was collected, otherwise a
        tuple of the collected values in the order of `args`.
    :raises ValueError: If the base version of diffusers is already >= a triple's `version_name`.
    :raises TypeError: If `take_from` is a dict that still holds keys after all triples were processed.
    """
    from .. import __version__

    source = take_from
    collected = []

    # Normalise the "three bare positionals" calling form into a one-element tuple of triples.
    if not isinstance(args[0], tuple):
        args = (args,)

    for attribute, version_name, message in args:
        installed = version.parse(version.parse(__version__).base_version)
        if installed >= version.parse(version_name):
            raise ValueError(
                f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
                f" version {__version__} is >= {version_name}"
            )

        if isinstance(source, dict) and attribute in source:
            collected.append(source.pop(attribute))
            notice = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
        elif hasattr(source, attribute):
            collected.append(getattr(source, attribute))
            notice = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
        elif source is None:
            notice = f"`{attribute}` is deprecated and will be removed in version {version_name}."
        else:
            # `take_from` was given but does not carry this attribute: nothing to warn about.
            continue

        prefix = notice + " " if standard_warn else ""
        warnings.warn(prefix + message, FutureWarning, stacklevel=stacklevel)

    # Any key left in the dict was never declared as deprecated, so report it like an unknown kwarg,
    # pointing at the caller's frame.
    if isinstance(source, dict) and len(source) > 0:
        caller = inspect.getouterframes(inspect.currentframe())[1]
        filename = caller.filename
        line_number = caller.lineno
        function = caller.function
        key = next(iter(source))
        raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")

    if not collected:
        return None
    if len(collected) == 1:
        return collected[0]
    return tuple(collected)
def replace_example_docstring(example_docstring):
    """
    Build a decorator that substitutes the `Examples:` placeholder line of a function's docstring.

    The decorated function's docstring is split into lines; the first line consisting only of `Example:` or
    `Examples:` (optionally surrounded by whitespace) is replaced by `example_docstring`. The function object
    itself is returned, with its `__doc__` updated in place.

    :param example_docstring: Text that takes the place of the placeholder line.
    :return: A decorator taking and returning the function.
    :raises ValueError: (from the decorator) if the function has no docstring or its docstring has no
        placeholder line. When the interpreter itself discards docstrings (`python -OO`), a missing docstring
        is not an error and the function is returned unchanged.
    """

    def docstring_decorator(fn):
        import sys

        func_doc = fn.__doc__

        if func_doc is None:
            # Under `-OO` every docstring is stripped, so there is nothing to patch and failing here would
            # only break imports. Otherwise report the missing placeholder with the usual error instead of
            # an opaque AttributeError on `None.split`.
            if sys.flags.optimize >= 2:
                return fn
            raise ValueError(
                f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, "
                f"current docstring is:\n{func_doc}"
            )

        lines = func_doc.split("\n")
        for index, line in enumerate(lines):
            if re.search(r"^\s*Examples?:\s*$", line) is not None:
                # Only the first placeholder is replaced.
                lines[index] = example_docstring
                break
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, "
                f"current docstring is:\n{func_doc}"
            )

        fn.__doc__ = "\n".join(lines)
        return fn

    return docstring_decorator
**kwargs): requires_backends(cls, ["flax", "transformers"]) class FlaxStableDiffusionPipeline(metaclass=DummyObject): _backends = ["flax", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax", "transformers"]) class FlaxStableDiffusionXLPipeline(metaclass=DummyObject): _backends = ["flax", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax", "transformers"]) ================================================ FILE: src/diffusers/utils/dummy_flax_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. 
from ..utils import DummyObject, requires_backends class FlaxControlNetModel(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxModelMixin(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxUNet2DConditionModel(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxAutoencoderKL(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxDiffusionPipeline(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxDDIMScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxDDPMScheduler(metaclass=DummyObject): 
_backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxEulerDiscreteScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxKarrasVeScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxLMSDiscreteScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxPNDMScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxSchedulerMixin(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): 
requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) ================================================ FILE: src/diffusers/utils/dummy_note_seq_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. from ..utils import DummyObject, requires_backends class MidiProcessor(metaclass=DummyObject): _backends = ["note_seq"] def __init__(self, *args, **kwargs): requires_backends(self, ["note_seq"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["note_seq"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["note_seq"]) ================================================ FILE: src/diffusers/utils/dummy_onnx_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. from ..utils import DummyObject, requires_backends class OnnxRuntimeModel(metaclass=DummyObject): _backends = ["onnx"] def __init__(self, *args, **kwargs): requires_backends(self, ["onnx"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["onnx"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["onnx"]) ================================================ FILE: src/diffusers/utils/dummy_pt_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. 
from ..utils import DummyObject, requires_backends class AsymmetricAutoencoderKL(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class AuraFlowTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class AutoencoderKL(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class AutoencoderKLCogVideoX(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class AutoencoderKLTemporalDecoder(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class AutoencoderOobleck(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class 
AutoencoderTiny(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class CogVideoXTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ConsistencyDecoderVAE(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ControlNetXSAdapter(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DiTTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class FluxTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def 
__init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class HunyuanDiT2DControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class HunyuanDiT2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class I2VGenXLUNet(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class Kandinsky3UNet(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class LatteTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) 
@classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class LuminaNextDiT2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ModelMixin(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class MotionAdapter(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class MultiAdapter(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class PixArtTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class PriorTransformer(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod 
def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class SD3ControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class SD3MultiControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class SD3Transformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class SparseControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class StableAudioDiTModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class T2IAdapter(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) 
class T5FilmDecoder(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class Transformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNet1DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNet2DConditionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNet2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNet3DConditionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNetControlNetXSModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, 
**kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNetMotionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNetSpatioTemporalConditionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UVit2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class VQModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) def get_constant_schedule_with_warmup(*args, **kwargs): requires_backends(get_constant_schedule_with_warmup, ["torch"]) def get_cosine_schedule_with_warmup(*args, **kwargs): requires_backends(get_cosine_schedule_with_warmup, ["torch"]) def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"]) def 
get_linear_schedule_with_warmup(*args, **kwargs): requires_backends(get_linear_schedule_with_warmup, ["torch"]) def get_polynomial_decay_schedule_with_warmup(*args, **kwargs): requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"]) def get_scheduler(*args, **kwargs): requires_backends(get_scheduler, ["torch"]) class AudioPipelineOutput(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class AutoPipelineForImage2Image(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class AutoPipelineForInpainting(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class AutoPipelineForText2Image(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class BlipDiffusionControlNetPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class 
BlipDiffusionPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class CLIPImageProjection(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ConsistencyModelPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DanceDiffusionPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDIMPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDPMPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DiffusionPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, 
*args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DiTPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ImagePipelineOutput(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class KarrasVePipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class LDMPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class LDMSuperResolutionPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class PNDMPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, 
**kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class RePaintPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ScoreSdeVePipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class StableDiffusionMixin(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class AmusedScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class CMStochasticIterativeScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class CogVideoXDDIMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def 
from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class CogVideoXDPMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDIMInverseScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDIMParallelScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDIMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDPMParallelScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDPMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class 
DDPMWuerstchenScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DEISMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DPMSolverMultistepInverseScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DPMSolverSinglestepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class EDMDPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class 
EDMEulerScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class EulerAncestralDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class EulerDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class FlowMatchEulerDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class FlowMatchHeunDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class HeunDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class IPNDMScheduler(metaclass=DummyObject): 
_backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class KarrasVeScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class KDPM2AncestralDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class KDPM2DiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class LCMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class PNDMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class RePaintScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, 
["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class SASolverScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class SchedulerMixin(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class TCDScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UnCLIPScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UniPCMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, 
["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class VQDiffusionScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class EMAModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) ================================================ FILE: src/diffusers/utils/dummy_torch_and_librosa_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. from ..utils import DummyObject, requires_backends class AudioDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "librosa"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "librosa"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "librosa"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "librosa"]) class Mel(metaclass=DummyObject): _backends = ["torch", "librosa"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "librosa"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "librosa"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "librosa"]) ================================================ FILE: src/diffusers/utils/dummy_torch_and_scipy_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. 
from ..utils import DummyObject, requires_backends


class LMSDiscreteScheduler(metaclass=DummyObject):
    """Dummy stand-in for `LMSDiscreteScheduler`, tied to the `torch` and `scipy` backends.

    The constructor and both class-level loaders accept any arguments, ignore
    them, and only call `requires_backends` with the same backend names that
    `_backends` lists.

    NOTE(review): `requires_backends` and the `DummyObject` metaclass are
    imported from `..utils` and are not visible here; presumably the call
    raises when a listed backend is missing -- confirm there.
    """

    # Backend names associated with this stand-in.
    _backends = ["torch", "scipy"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but unused; only the backend check runs.
        requires_backends(self, ["torch", "scipy"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Same backend check, issued against the class instead of an instance.
        requires_backends(cls, ["torch", "scipy"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Same backend check, issued against the class instead of an instance.
        requires_backends(cls, ["torch", "scipy"])


================================================
FILE: src/diffusers/utils/dummy_torch_and_torchsde_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


class CosineDPMSolverMultistepScheduler(metaclass=DummyObject):
    """Dummy stand-in for `CosineDPMSolverMultistepScheduler` (`torch` + `torchsde`).

    The constructor and both class-level loaders accept any arguments, ignore
    them, and only call `requires_backends` with the same backend names that
    `_backends` lists.

    NOTE(review): the behavior of `requires_backends` is defined in `..utils`;
    presumably it raises when a listed backend is missing -- confirm there.
    """

    # Backend names associated with this stand-in.
    _backends = ["torch", "torchsde"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but unused; only the backend check runs.
        requires_backends(self, ["torch", "torchsde"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Same backend check, issued against the class instead of an instance.
        requires_backends(cls, ["torch", "torchsde"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Same backend check, issued against the class instead of an instance.
        requires_backends(cls, ["torch", "torchsde"])


class DPMSolverSDEScheduler(metaclass=DummyObject):
    """Dummy stand-in for `DPMSolverSDEScheduler` (`torch` + `torchsde`).

    The constructor and both class-level loaders accept any arguments, ignore
    them, and only call `requires_backends` with the same backend names that
    `_backends` lists.

    NOTE(review): the behavior of `requires_backends` is defined in `..utils`;
    presumably it raises when a listed backend is missing -- confirm there.
    """

    # Backend names associated with this stand-in.
    _backends = ["torch", "torchsde"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but unused; only the backend check runs.
        requires_backends(self, ["torch", "torchsde"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Same backend check, issued against the class instead of an instance.
        requires_backends(cls, ["torch", "torchsde"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Same backend check, issued against the class instead of an instance.
        requires_backends(cls, ["torch", "torchsde"])


================================================
FILE: src/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


class StableDiffusionKDiffusionPipeline(metaclass=DummyObject):
    """Dummy stand-in for `StableDiffusionKDiffusionPipeline`.

    Tied to the `torch`, `transformers` and `k_diffusion` backends. The
    constructor and both class-level loaders accept any arguments, ignore
    them, and only call `requires_backends` with the same backend names that
    `_backends` lists.

    NOTE(review): `requires_backends` and the `DummyObject` metaclass are
    imported from `..utils` and are not visible here; presumably the call
    raises when a listed backend is missing -- confirm there.
    """

    # Backend names associated with this stand-in.
    _backends = ["torch", "transformers", "k_diffusion"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but unused; only the backend check runs.
        requires_backends(self, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Same backend check, issued against the class instead of an instance.
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Same backend check, issued against the class instead of an instance.
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])


class StableDiffusionXLKDiffusionPipeline(metaclass=DummyObject):
    """Dummy stand-in for `StableDiffusionXLKDiffusionPipeline`.

    Tied to the `torch`, `transformers` and `k_diffusion` backends. The
    constructor and both class-level loaders accept any arguments, ignore
    them, and only call `requires_backends` with the same backend names that
    `_backends` lists.

    NOTE(review): the behavior of `requires_backends` is defined in `..utils`;
    presumably it raises when a listed backend is missing -- confirm there.
    """

    # Backend names associated with this stand-in.
    _backends = ["torch", "transformers", "k_diffusion"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but unused; only the backend check runs.
        requires_backends(self, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Same backend check, issued against the class instead of an instance.
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Same backend check, issued against the class instead of an instance.
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])


================================================
FILE: src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "onnx"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers", "onnx"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "onnx"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "onnx"]) class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "onnx"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers", "onnx"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "onnx"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "onnx"]) class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): _backends = ["torch", "transformers", "onnx"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers", "onnx"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "onnx"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "onnx"]) class OnnxStableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "onnx"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers", "onnx"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "onnx"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "onnx"]) class OnnxStableDiffusionUpscalePipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "onnx"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers", "onnx"]) @classmethod def 
from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "onnx"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "onnx"]) class StableDiffusionOnnxPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "onnx"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers", "onnx"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "onnx"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "onnx"]) ================================================ FILE: src/diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. from ..utils import DummyObject, requires_backends class KolorsImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "sentencepiece"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers", "sentencepiece"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "sentencepiece"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "sentencepiece"]) class KolorsPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "sentencepiece"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers", "sentencepiece"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "sentencepiece"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "sentencepiece"]) class KolorsPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "sentencepiece"] def __init__(self, *args, **kwargs): 
requires_backends(self, ["torch", "transformers", "sentencepiece"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "sentencepiece"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "sentencepiece"]) ================================================ FILE: src/diffusers/utils/dummy_torch_and_transformers_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. from ..utils import DummyObject, requires_backends class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AltDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AmusedImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AmusedInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): 
requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AmusedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AnimateDiffControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AnimateDiffPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AnimateDiffPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AnimateDiffSDXLPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def 
from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AnimateDiffSparseControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AudioLDM2Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AudioLDM2ProjectionModel(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AudioLDM2UNet2DConditionModel(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): 
requires_backends(cls, ["torch", "transformers"]) class AudioLDMPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AuraFlowPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class CLIPImageProjection(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class CogVideoXPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class CycleDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class 
FluxPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class HunyuanDiTControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class HunyuanDiTPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class HunyuanDiTPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class I2VGenXLPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] 
def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFImg2ImgSuperResolutionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFInpaintingPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFInpaintingSuperResolutionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFSuperResolutionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): 
requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class ImageTextPipelineOutput(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class Kandinsky3Img2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class Kandinsky3Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyCombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyImg2ImgCombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) 
@classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyInpaintCombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyPriorPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): 
requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22CombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22ControlnetImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22ControlnetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22Img2ImgCombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22Img2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): 
requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22InpaintCombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22InpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22PriorEmb2EmbPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22PriorPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", 
"transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class LatentConsistencyModelImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class LatentConsistencyModelPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class LattePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class LDMTextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class LEditsPPPipelineStableDiffusion(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def 
from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class LuminaText2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class MarigoldDepthPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class MarigoldNormalsPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class MusicLDMPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): 
requires_backends(cls, ["torch", "transformers"]) class PaintByExamplePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class PIAPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class PixArtAlphaPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class PixArtSigmaPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class PixArtSigmaPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class 
SemanticStableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class ShapEImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class ShapEPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableAudioPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableAudioProjectionModel(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableCascadeCombinedPipeline(metaclass=DummyObject): _backends = 
["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableCascadeDecoderPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableCascadePriorPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusion3ControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusion3Img2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusion3InpaintPipeline(metaclass=DummyObject): _backends = ["torch", 
"transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusion3PAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusion3Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionAdapterPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionControlNetImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] 
def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionControlNetPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionControlNetXSPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject): _backends = ["torch", 
"transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionDiffEditPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionGLIGENPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionGLIGENTextImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", 
"transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionInstructPix2PixPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionLatentUpscalePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionLDM3DPipeline(metaclass=DummyObject): _backends = ["torch", 
"transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionModelEditingPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionPanoramaPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionParadigmsPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def 
__init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionPipelineSafe(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionSAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionUpscalePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionXLAdapterPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, 
**kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionXLControlNetImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionXLControlNetInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionXLControlNetPAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionXLControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def 
# Placeholder for `StableDiffusionXLImg2ImgPipeline`: the constructor and both loader
# classmethods only call `requires_backends` with the "torch" and "transformers" names.
# NOTE(review): presumably `requires_backends` raises when a listed backend is missing --
# confirm against its definition.
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
    # Backend names this placeholder declares as required.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never read; only the backend check runs.
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Backend check on the class; no configuration is actually read here.
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Backend check on the class; nothing is loaded here.
        requires_backends(cls, ["torch", "transformers"])
# Placeholder for `StableDiffusionXLPAGPipeline`: the constructor and both loader classmethods
# only call `requires_backends` with the "torch" and "transformers" names.
# NOTE(review): presumably `requires_backends` raises when a listed backend is missing --
# confirm against its definition.
class StableDiffusionXLPAGPipeline(metaclass=DummyObject):
    # Backend names this placeholder declares as required.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never read; only the backend check runs.
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Backend check on the class; no configuration is actually read here.
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Backend check on the class; nothing is loaded here.
        requires_backends(cls, ["torch", "transformers"])
# Placeholder for `TextToVideoSDPipeline`: the constructor and both loader classmethods only
# call `requires_backends` with the "torch" and "transformers" names.
# NOTE(review): presumably `requires_backends` raises when a listed backend is missing --
# confirm against its definition.
class TextToVideoSDPipeline(metaclass=DummyObject):
    # Backend names this placeholder declares as required.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never read; only the backend check runs.
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Backend check on the class; no configuration is actually read here.
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Backend check on the class; nothing is loaded here.
        requires_backends(cls, ["torch", "transformers"])
# Placeholder for `UniDiffuserModel`: the constructor and both loader classmethods only call
# `requires_backends` with the "torch" and "transformers" names.
# NOTE(review): presumably `requires_backends` raises when a listed backend is missing --
# confirm against its definition.
class UniDiffuserModel(metaclass=DummyObject):
    # Backend names this placeholder declares as required.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never read; only the backend check runs.
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Backend check on the class; no configuration is actually read here.
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Backend check on the class; nothing is loaded here.
        requires_backends(cls, ["torch", "transformers"])
# Placeholder for `VersatileDiffusionPipeline`: the constructor and both loader classmethods
# only call `requires_backends` with the "torch" and "transformers" names.
# NOTE(review): presumably `requires_backends` raises when a listed backend is missing --
# confirm against its definition.
class VersatileDiffusionPipeline(metaclass=DummyObject):
    # Backend names this placeholder declares as required.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never read; only the backend check runs.
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Backend check on the class; no configuration is actually read here.
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Backend check on the class; nothing is loaded here.
        requires_backends(cls, ["torch", "transformers"])
# Placeholder for `SpectrogramDiffusionPipeline`: the constructor and both loader classmethods
# only call `requires_backends` with the "transformers", "torch" and "note_seq" names.
# The containing file states it is autogenerated by `make fix-copies` and must not be edited.
# NOTE(review): presumably `requires_backends` raises when a listed backend is missing --
# confirm against its definition in `..utils`.
class SpectrogramDiffusionPipeline(metaclass=DummyObject):
    # Backend names this placeholder declares as required.
    _backends = ["transformers", "torch", "note_seq"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never read; only the backend check runs.
        requires_backends(self, ["transformers", "torch", "note_seq"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Backend check on the class; no configuration is actually read here.
        requires_backends(cls, ["transformers", "torch", "note_seq"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Backend check on the class; nothing is loaded here.
        requires_backends(cls, ["transformers", "torch", "note_seq"])
def get_diffusers_versions():
    """
    Fetch the `diffusers` release list from the PyPI JSON API and return it sorted in ascending version order.

    Returns:
        `List[str]`: The keys of the `"releases"` mapping returned by PyPI, ordered by
        `packaging.version.Version`.
    """
    url = "https://pypi.org/pypi/diffusers/json"
    # Read the payload inside a `with` block so the HTTP response is closed as soon as it has been consumed,
    # instead of being left open until garbage collection.
    with request.urlopen(url) as response:
        releases = json.loads(response.read())["releases"].keys()
    return sorted(releases, key=lambda x: version.Version(x))
def get_relative_imports(module_file):
    """
    Get the list of modules that are relatively imported in a module file.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The de-duplicated names found after the leading dot of each relative import, in no
        particular order.
    """
    with open(module_file, "r", encoding="utf-8") as f:
        content = f.read()

    # Imports of the form `import .xxx`
    relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from .xxx import yyy`
    relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
    # Unique-ify
    return list(set(relative_imports))


def get_relative_import_files(module_file):
    """
    Get the list of all files that are needed for a given module. Note that this function recurses through the
    relative imports (if a imports b and b imports c, it will return module files for b and c).

    Every file is reported (and scanned) at most once, so circular relative imports terminate and the entry module
    itself is never part of the result.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: Paths (ending in `.py`) of the relatively imported files, all resolved against the directory
        of `module_file`, in discovery order.
    """
    module_path = Path(module_file).parent
    # Paths are tracked *with* their `.py` suffix. Seeding the set with the entry module prevents a cycle that
    # leads back to it from being scanned again.
    seen = {str(Path(module_file))}
    all_relative_imports = []
    files_to_check = [module_file]

    # Breadth-first walk: each round scans the files discovered by the previous one.
    while files_to_check:
        new_imports = []
        for f in files_to_check:
            new_imports.extend(get_relative_imports(f))

        files_to_check = []
        for m in new_imports:
            candidate = f"{module_path / m}.py"
            if candidate not in seen:
                seen.add(candidate)
                files_to_check.append(candidate)

        all_relative_imports.extend(files_to_check)

    return all_relative_imports
def get_class_in_module(class_name, module_path):
    """
    Import a module located in the modules cache directory and return one class defined in it.

    Args:
        class_name (`str` or `None`):
            Name of the attribute to fetch from the imported module. When `None`, the lookup is delegated to
            `find_pipeline_class`.
        module_path (`str`):
            Path of the module, using `os.path.sep` as separator; it is turned into a dotted module name before
            being imported.

    Returns:
        The attribute named `class_name` of the imported module, or the result of `find_pipeline_class`.
    """
    dotted_name = ".".join(module_path.split(os.path.sep))
    loaded_module = importlib.import_module(dotted_name)
    return find_pipeline_class(loaded_module) if class_name is None else getattr(loaded_module, class_name)
@validate_hf_hub_args
def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
):
    """
    Fetches a module file from a local folder, the community pipelines mirror or a Hub repo, copies it into the
    dynamic modules cache and returns its path inside that cache.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
              namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
            - a string without any `/` that is not a local directory containing `module_file`: it is then treated
              as the name of a community pipeline and looked up in the community pipelines mirror.

        module_file (`str`):
            The name of the module file containing the class to look for.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if
            they exist.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git. For community pipelines it must be a released `diffusers` version or
            `"main"`; when omitted, the installed version is used if it was released, else `"main"`.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.

    <Tip>

    You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
    [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).

    </Tip>

    Returns:
        `str`: The path to the module inside the cache.

    Raises:
        `ValueError`: If a community pipeline is requested with a `revision` that is neither a released version
            nor `"main"`.
        `EnvironmentError`: If the file (or the requested revision of the mirror) cannot be located.
    """
    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)

    module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)

    if os.path.isfile(module_file_or_url):
        # Case 1: the module file already exists in a local directory.
        resolved_module_file = module_file_or_url
        submodule = "local"
    elif pretrained_model_name_or_path.count("/") == 0:
        # Case 2: a bare name, resolved against the community pipelines mirror.
        available_versions = get_diffusers_versions()
        # cut ".dev0"
        latest_version = "v" + ".".join(__version__.split(".")[:3])

        # retrieve github version that matches
        if revision is None:
            revision = latest_version if latest_version[1:] in available_versions else "main"
            logger.info(f"Defaulting to latest_version: {revision}.")
        elif revision in available_versions:
            revision = f"v{revision}"
        elif revision != "main":
            # "main" is accepted as-is; anything else that is not a released version is rejected.
            raise ValueError(
                f"`custom_revision`: {revision} does not exist. Please make sure to choose one of"
                f" {', '.join(available_versions + ['main'])}."
            )

        try:
            resolved_module_file = hf_hub_download(
                repo_id=COMMUNITY_PIPELINES_MIRROR_ID,
                repo_type="dataset",
                filename=f"{revision}/{pretrained_model_name_or_path}.py",
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
            )
            submodule = "git"
            module_file = pretrained_model_name_or_path + ".py"
        except RevisionNotFoundError as e:
            raise EnvironmentError(
                f"Revision '{revision}' not found in the community pipelines mirror. Check available revisions on"
                " https://huggingface.co/datasets/diffusers/community-pipelines-mirror/tree/main."
                " If you don't find the revision you are looking for, please open an issue on https://github.com/huggingface/diffusers/issues."
            ) from e
        except EnvironmentError:
            logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
            raise
    else:
        # Case 3: a namespaced Hub repo id.
        try:
            # Load from URL or cache if already cached
            resolved_module_file = hf_hub_download(
                pretrained_model_name_or_path,
                module_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                # Download from the same revision whose commit hash names the cache folder below; without this
                # the file always came from the default branch regardless of `revision`.
                revision=revision,
            )
            submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
        except EnvironmentError:
            logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
            raise

    # Check we have all the requirements in our environment
    modules_needed = check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    if submodule == "local" or submodule == "git":
        # We always copy local files (we could hash the file to see if there was a change, and give them the name of
        # that hash, to only copy when there is a modification but it seems overkill for now).
        # The only reason we do the copy is to avoid putting too many folders in sys.path.
        shutil.copy(resolved_module_file, submodule_path / module_file)
        for module_needed in modules_needed:
            # A two-part dotted name (`pkg.mod`) maps to a file inside a sub-folder (`pkg/mod.py`).
            if len(module_needed.split(".")) == 2:
                module_needed = "/".join(module_needed.split("."))
                module_folder = module_needed.split("/")[0]
                os.makedirs(submodule_path / module_folder, exist_ok=True)
            module_needed = f"{module_needed}.py"
            shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
    else:
        # Get the commit hash
        # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
        commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha

        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        if not (submodule_path / module_file).exists():
            if len(module_file.split("/")) == 2:
                module_folder = module_file.split("/")[0]
                os.makedirs(submodule_path / module_folder, exist_ok=True)
            shutil.copy(resolved_module_file, submodule_path / module_file)
        # Make sure we also have every file with relative
        for module_needed in modules_needed:
            if len(module_needed.split(".")) == 2:
                module_needed = "/".join(module_needed.split("."))
            # Test for the actual cached file (with its `.py` suffix); the suffix-less path never exists, which
            # made every call re-fetch every dependency.
            if not (submodule_path / f"{module_needed}.py").exists():
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    token=token,
                    revision=revision,
                    local_files_only=local_files_only,
                )
    return os.path.join(full_submodule, module_file)
proxies: Optional[Dict[str, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, local_files_only: bool = False, **kwargs, ): """ Extracts a class from a module file, present in the local folder or repository of a model. Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should therefore only be called on trusted repos. Args: pretrained_model_name_or_path (`str` or `os.PathLike`): This can be either: - a string, the *model id* of a pretrained model configuration hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. - a path to a *directory* containing a configuration file saved using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. module_file (`str`): The name of the module file containing the class to look for. class_name (`str`): The name of the class to import in the module. cache_dir (`str` or `os.PathLike`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force to (re-)download the configuration files and override the cached versions if they exist. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `transformers-cli login` (stored in `~/.huggingface`). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. 
@contextmanager
def buffered_writer(raw_f):
    """
    Context manager that wraps `raw_f` in an `io.BufferedWriter` and flushes it when the `with` block exits.

    The flush also happens when the body raises, so bytes that were already written are pushed to `raw_f` instead
    of staying in the buffer. This function never calls `close()` itself.

    Args:
        raw_f: The raw binary stream to wrap; it is passed unchanged to `io.BufferedWriter`.

    Yields:
        `io.BufferedWriter`: The buffered writer to write to.
    """
    f = io.BufferedWriter(raw_f)
    try:
        yield f
    finally:
        # Runs on both normal exit and exceptions raised inside the `with` body.
        f.flush()
append_images=image[1:], optimize=False, duration=1000 // fps, loop=0, ) return output_gif_path def export_to_ply(mesh, output_ply_path: str = None): """ Write a PLY file for a mesh. """ if output_ply_path is None: output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name coords = mesh.verts.detach().cpu().numpy() faces = mesh.faces.cpu().numpy() rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1) with buffered_writer(open(output_ply_path, "wb")) as f: f.write(b"ply\n") f.write(b"format binary_little_endian 1.0\n") f.write(bytes(f"element vertex {len(coords)}\n", "ascii")) f.write(b"property float x\n") f.write(b"property float y\n") f.write(b"property float z\n") if rgb is not None: f.write(b"property uchar red\n") f.write(b"property uchar green\n") f.write(b"property uchar blue\n") if faces is not None: f.write(bytes(f"element face {len(faces)}\n", "ascii")) f.write(b"property list uchar int vertex_index\n") f.write(b"end_header\n") if rgb is not None: rgb = (rgb * 255.499).round().astype(int) vertices = [ (*coord, *rgb) for coord, rgb in zip( coords.tolist(), rgb.tolist(), ) ] format = struct.Struct("<3f3B") for item in vertices: f.write(format.pack(*item)) else: format = struct.Struct("<3f") for vertex in coords.tolist(): f.write(format.pack(*vertex)) if faces is not None: format = struct.Struct(" str: if is_opencv_available(): import cv2 else: raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video")) if output_video_path is None: output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name if isinstance(video_frames[0], np.ndarray): video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames] elif isinstance(video_frames[0], PIL.Image.Image): video_frames = [np.array(frame) for frame in video_frames] fourcc = cv2.VideoWriter_fourcc(*"mp4v") h, w, c = video_frames[0].shape video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h)) for i in 
range(len(video_frames)): img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR) video_writer.write(img) return output_video_path ================================================ FILE: src/diffusers/utils/hub_utils.py ================================================ # coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import os import re import sys import tempfile import traceback import warnings from pathlib import Path from typing import Dict, List, Optional, Union from uuid import uuid4 from huggingface_hub import ( ModelCard, ModelCardData, create_repo, hf_hub_download, model_info, snapshot_download, upload_folder, ) from huggingface_hub.constants import HF_HUB_CACHE, HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE from huggingface_hub.file_download import REGEX_COMMIT_HASH from huggingface_hub.utils import ( EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError, is_jinja_available, validate_hf_hub_args, ) from packaging import version from requests import HTTPError from .. 
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Build the user-agent string sent along with Hub requests.

    The string always starts with the diffusers version, the Python version and the session id. When telemetry is
    disabled or the Hub is in offline mode, only `"; telemetry/off"` is appended and nothing else is reported.

    Args:
        user_agent (`dict` or `str`, *optional*):
            Extra information to append. A dict is rendered as `key/value` pairs separated by `"; "`, a string is
            appended verbatim.

    Returns:
        `str`: The formatted user-agent string.
    """
    base = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
    if HF_HUB_DISABLE_TELEMETRY or HF_HUB_OFFLINE:
        return base + "; telemetry/off"

    # Every optional piece carries its own leading separator, so a plain concatenation is enough at the end.
    pieces = [base]
    if is_torch_available():
        pieces.append(f"; torch/{_torch_version}")
    if is_flax_available():
        pieces.append(f"; jax/{_jax_version}")
        pieces.append(f"; flax/{_flax_version}")
    if is_onnx_available():
        pieces.append(f"; onnxruntime/{_onnxruntime_version}")

    # CI will set this value to True
    running_in_ci = os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES
    if running_in_ci:
        pieces.append("; is_ci/true")

    if isinstance(user_agent, dict):
        pieces.append("; " + "; ".join(f"{key}/{value}" for key, value in user_agent.items()))
    elif isinstance(user_agent, str):
        pieces.append("; " + user_agent)

    return "".join(pieces)
is_pipeline (`bool`): Boolean to indicate if we're adding tag to a [`DiffusionPipeline`]. from_training: (`bool`): Boolean flag to denote if the model card is being created from a training script. model_description (`str`, *optional*): Model description to add to the model card. Helpful when using `load_or_create_model_card` from a training script. base_model (`str`): Base model identifier (e.g., "stabilityai/stable-diffusion-xl-base-1.0"). Useful for DreamBooth-like training. prompt (`str`, *optional*): Prompt used for training. Useful for DreamBooth-like training. license: (`str`, *optional*): License of the output artifact. Helpful when using `load_or_create_model_card` from a training script. widget (`List[dict]`, *optional*): Widget to accompany a gallery template. inference: (`bool`, optional): Whether to turn on inference widget. Helpful when using `load_or_create_model_card` from a training script. """ if not is_jinja_available(): raise ValueError( "Modelcard rendering is based on Jinja templates." " Please make sure to have `jinja` installed before using `load_or_create_model_card`." " To install it, please run `pip install Jinja2`." ) try: # Check if the model card is present on the remote repo model_card = ModelCard.load(repo_id_or_path, token=token) except (EntryNotFoundError, RepositoryNotFoundError): # Otherwise create a model card from template if from_training: model_card = ModelCard.from_template( card_data=ModelCardData( # Card metadata object that will be converted to YAML block license=license, library_name="diffusers", inference=inference, base_model=base_model, instance_prompt=prompt, widget=widget, ), template_path=MODEL_CARD_TEMPLATE_PATH, model_description=model_description, ) else: card_data = ModelCardData() component = "pipeline" if is_pipeline else "model" if model_description is None: model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. 
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
    """
    Recover the commit hash embedded in the path of a resolved cache file.

    Args:
        resolved_file (`str`, *optional*):
            Path of a file inside the cache; the hash is read from its `snapshots/<hash>/` component.
        commit_hash (`str`, *optional*):
            An already known commit hash. When given, it is returned unchanged and the path is not inspected.

    Returns:
        `commit_hash` if it was provided, otherwise the path component following `snapshots/` when it matches
        `REGEX_COMMIT_HASH`, otherwise `None`.
    """
    # An explicit hash wins, and without a file there is nothing to parse.
    if commit_hash is not None or resolved_file is None:
        return commit_hash

    # Normalise separators so the pattern below also applies to Windows-style paths.
    posix_path = Path(resolved_file).as_posix()
    found = re.search(r"snapshots/([^/]+)/", posix_path)
    if found is None:
        return None

    candidate = found.group(1)
    if REGEX_COMMIT_HASH.match(candidate):
        return candidate
    return None
def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
    """
    Move every blob file found under `old_cache_dir` to the same relative location under `new_cache_dir`.

    After a blob has been moved, a symlink pointing at its new location is created at the old path so that the old
    cache directory can still be used. If the symlink cannot be created, a warning is logged and the move is kept.

    Args:
        old_cache_dir (`str`, *optional*):
            Directory to migrate from. Defaults to `old_diffusers_cache`.
        new_cache_dir (`str`, *optional*):
            Directory to migrate to. Defaults to `HF_HUB_CACHE`.
    """
    target_root = Path(HF_HUB_CACHE if new_cache_dir is None else new_cache_dir).expanduser()
    source_root = Path(old_diffusers_cache if old_cache_dir is None else old_cache_dir).expanduser()

    for blob in source_root.glob("**/blobs/*"):
        # Only regular files are migrated; symlinks and directories are left where they are.
        if blob.is_symlink() or not blob.is_file():
            continue

        destination = target_root / blob.relative_to(source_root)
        destination.parent.mkdir(parents=True, exist_ok=True)
        os.replace(blob, destination)

        try:
            os.symlink(destination, blob)
        except OSError:
            logger.warning(
                "Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded."
            )
    # At this point, old_cache_dir contains symlinks to the new cache (it can still be used).
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    """
    Insert `variant` as an extra dot-separated component of `weights_name`.

    The variant is placed after the first dot-separated component of the name, or after the second one when the
    name ends with `.index.json`. When `variant` is `None` the name is returned unchanged.
    """
    if variant is None:
        return weights_name

    parts = weights_name.split(".")
    # Number of leading components that stay in front of the variant.
    insert_at = 2 if weights_name.endswith(".index.json") else 1
    return ".".join([*parts[:insert_at], variant, *parts[insert_at:]])
weights_name) return model_file else: raise EnvironmentError( f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." ) else: # 1. First check if deprecated way of loading from branches is used if ( revision in DEPRECATED_REVISION_ARGS and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME) and version.parse(version.parse(__version__).base_version) >= version.parse("0.22.0") ): try: model_file = hf_hub_download( pretrained_model_name_or_path, filename=_add_variant(weights_name, revision), cache_dir=cache_dir, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token, user_agent=user_agent, subfolder=subfolder, revision=revision or commit_hash, ) warnings.warn( f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", FutureWarning, ) return model_file except: # noqa: E722 warnings.warn( f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.", FutureWarning, ) try: # 2. 
Load model file as usual model_file = hf_hub_download( pretrained_model_name_or_path, filename=weights_name, cache_dir=cache_dir, force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token, user_agent=user_agent, subfolder=subfolder, revision=revision or commit_hash, ) return model_file except RepositoryNotFoundError as e: raise EnvironmentError( f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " "token having permission to this repo with `token` or log in with `huggingface-cli " "login`." ) from e except RevisionNotFoundError as e: raise EnvironmentError( f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " "this model name. Check the model page at " f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." ) from e except EntryNotFoundError as e: raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}." ) from e except HTTPError as e: raise EnvironmentError( f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{e}" ) from e except ValueError as e: raise EnvironmentError( f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" f" directory containing a file named {weights_name} or" " \nCheckout your internet connection or see how to run the library in" " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." ) from e except EnvironmentError as e: raise EnvironmentError( f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " "'https://huggingface.co/models', make sure you don't have a local directory with the same name. 
def _get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    proxies=None,
    local_files_only=False,
    token=None,
    user_agent=None,
    revision=None,
    subfolder="",
):
    """
    Resolve the folder holding all shards of a sharded checkpoint, downloading them from the Hub if needed.

    For a given model:

    - if `pretrained_model_name_or_path` is a local directory, check that every shard listed in the index exists
      under `subfolder`;
    - otherwise treat it as a model ID on the Hub, and download and cache the shards listed in the index.

    `index_filename` is the full path to the checkpoint index (already downloaded and cached if
    `pretrained_model_name_or_path` is a model ID on the Hub).

    Returns:
        `tuple`: `(folder, sharded_metadata)` where `folder` is the directory containing the shard files and
        `sharded_metadata` is the `metadata` entry of the index, extended with `all_checkpoint_keys` and
        `weight_map`.

    Raises:
        `ValueError`: If `index_filename` does not exist, or if a shard required by the index is missing from the
            local directory or from the local cache.
        `EnvironmentError`: If a shard required by the index is not listed in the Hub repository, or if the Hub
            could not be reached.
    """
    if not os.path.isfile(index_filename):
        raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")

    with open(index_filename, "r") as f:
        index = json.load(f)

    original_shard_filenames = sorted(set(index["weight_map"].values()))
    sharded_metadata = index["metadata"]
    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
    sharded_metadata["weight_map"] = index["weight_map"].copy()
    shards_path = os.path.join(pretrained_model_name_or_path, subfolder)

    # First, let's deal with local folder.
    if os.path.isdir(pretrained_model_name_or_path):
        _check_if_shards_exist_locally(
            pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
        )
        return shards_path, sharded_metadata

    # At this stage pretrained_model_name_or_path is a model identifier on the Hub
    allow_patterns = original_shard_filenames
    if subfolder is not None:
        allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]

    ignore_patterns = ["*.json", "*.md"]

    if not local_files_only:
        # `model_info` needs network access, so it must stay guarded by the condition above. The token is forwarded
        # so that private repositories can be inspected as well.
        model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
        for shard_file in original_shard_filenames:
            shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
            if not shard_file_present:
                raise EnvironmentError(
                    f"{shards_path} does not appear to have a file named {shard_file} which is "
                    "required according to the checkpoint index."
                )

    try:
        # With `local_files_only=True` this performs no download and only resolves the cached snapshot folder, so
        # `cached_folder` is defined on both code paths.
        cached_folder = snapshot_download(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            user_agent=user_agent,
        )
    # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
    # we don't have to catch them here. We have also dealt with EntryNotFoundError.
    except HTTPError as e:
        raise EnvironmentError(
            f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
            " again after checking your internet connection."
        ) from e

    # If `local_files_only=True`, `cached_folder` may not contain all the shard files, so verify them inside the
    # resolved snapshot folder (not the cache root).
    if local_files_only:
        _check_if_shards_exist_locally(
            local_dir=cached_folder, subfolder=subfolder, original_shard_filenames=original_shard_filenames
        )

    if subfolder is not None:
        cached_folder = os.path.join(cached_folder, subfolder)

    return cached_folder, sharded_metadata
create_pr (`bool`, *optional*, defaults to `False`): Whether or not to create a PR with the uploaded files or directly commit. safe_serialization (`bool`, *optional*, defaults to `True`): Whether or not to convert the model weights to the `safetensors` format. variant (`str`, *optional*): If specified, weights are saved in the format `pytorch_model..bin`. Examples: ```python from diffusers import UNet2DConditionModel unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="unet") # Push the `unet` to your namespace with the name "my-finetuned-unet". unet.push_to_hub("my-finetuned-unet") # Push the `unet` to an organization with the name "my-finetuned-unet". unet.push_to_hub("your-org/my-finetuned-unet") ``` """ repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id # Create a new empty model card and eventually tag it model_card = load_or_create_model_card(repo_id, token=token) model_card = populate_model_card(model_card) # Save all files. save_kwargs = {"safe_serialization": safe_serialization} if "Scheduler" not in self.__class__.__name__: save_kwargs.update({"variant": variant}) with tempfile.TemporaryDirectory() as tmpdir: self.save_pretrained(tmpdir, **save_kwargs) # Update model card if needed: model_card.save(os.path.join(tmpdir, "README.md")) return self._upload_folder( tmpdir, repo_id, token=token, commit_message=commit_message, create_pr=create_pr, ) ================================================ FILE: src/diffusers/utils/import_utils.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Import utilities: Utilities related to imports and our lazy inits. """ import importlib.util import operator as op import os import sys from collections import OrderedDict from itertools import chain from types import ModuleType from typing import Any, Union from huggingface_hub.utils import is_jinja_available # noqa: F401 from packaging import version from packaging.version import Version, parse from . import logging # The package importlib_metadata is in a different place, depending on the python version. if sys.version_info < (3, 8): import importlib_metadata else: import importlib.metadata as importlib_metadata logger = logging.get_logger(__name__) # pylint: disable=invalid-name ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper() DIFFUSERS_SLOW_IMPORT = os.environ.get("DIFFUSERS_SLOW_IMPORT", "FALSE").upper() DIFFUSERS_SLOW_IMPORT = DIFFUSERS_SLOW_IMPORT in ENV_VARS_TRUE_VALUES STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} _torch_version = "N/A" if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: _torch_available = importlib.util.find_spec("torch") is not None if _torch_available: try: _torch_version = importlib_metadata.version("torch") logger.info(f"PyTorch version 
# Safetensors is probed only when USE_SAFETENSORS is one of the "true"/"auto" values; otherwise it is reported as
# unavailable without looking for the package.
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _safetensors_available = importlib.util.find_spec("safetensors") is not None
    if _safetensors_available:
        try:
            _safetensors_version = importlib_metadata.version("safetensors")
            logger.info(f"Safetensors version {_safetensors_version} available.")
        except importlib_metadata.PackageNotFoundError:
            # An importable module without distribution metadata is treated as not installed.
            _safetensors_available = False
else:
    # The branch is controlled by USE_SAFETENSORS, so the message names that variable.
    logger.info("Disabling Safetensors because USE_SAFETENSORS is set")
    _safetensors_available = False
{_transformers_version}") except importlib_metadata.PackageNotFoundError: _transformers_available = False _inflect_available = importlib.util.find_spec("inflect") is not None try: _inflect_version = importlib_metadata.version("inflect") logger.debug(f"Successfully imported inflect version {_inflect_version}") except importlib_metadata.PackageNotFoundError: _inflect_available = False _unidecode_available = importlib.util.find_spec("unidecode") is not None try: _unidecode_version = importlib_metadata.version("unidecode") logger.debug(f"Successfully imported unidecode version {_unidecode_version}") except importlib_metadata.PackageNotFoundError: _unidecode_available = False _onnxruntime_version = "N/A" _onnx_available = importlib.util.find_spec("onnxruntime") is not None if _onnx_available: candidates = ( "onnxruntime", "onnxruntime-gpu", "ort_nightly_gpu", "onnxruntime-directml", "onnxruntime-openvino", "ort_nightly_directml", "onnxruntime-rocm", "onnxruntime-training", ) _onnxruntime_version = None # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu for pkg in candidates: try: _onnxruntime_version = importlib_metadata.version(pkg) break except importlib_metadata.PackageNotFoundError: pass _onnx_available = _onnxruntime_version is not None if _onnx_available: logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") # (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed. 
# `importlib.util.find_spec("opencv-python")` returns None even when OpenCV is installed, so availability is derived
# from distribution metadata instead, trying each known distribution name in turn.
candidates = (
    "opencv-python",
    "opencv-contrib-python",
    "opencv-python-headless",
    "opencv-contrib-python-headless",
)
_opencv_version = None
for pkg in candidates:
    try:
        _opencv_version = importlib_metadata.version(pkg)
        break
    except importlib_metadata.PackageNotFoundError:
        # Not installed under this name; try the next one.
        pass
# OpenCV counts as available as soon as one of the distributions reported a version.
_opencv_available = _opencv_version is not None
if _opencv_available:
    logger.debug(f"Successfully imported cv2 version {_opencv_version}")
# Each optional dependency below is considered available only if its module can be located AND its distribution
# metadata can be read. The availability flags are always plain booleans.
_tensorboard_available = importlib.util.find_spec("tensorboard") is not None
try:
    _tensorboard_version = importlib_metadata.version("tensorboard")
    logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
except importlib_metadata.PackageNotFoundError:
    _tensorboard_available = False


_compel_available = importlib.util.find_spec("compel") is not None
try:
    _compel_version = importlib_metadata.version("compel")
    logger.debug(f"Successfully imported compel version {_compel_version}")
except importlib_metadata.PackageNotFoundError:
    _compel_available = False


_ftfy_available = importlib.util.find_spec("ftfy") is not None
try:
    _ftfy_version = importlib_metadata.version("ftfy")
    logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
except importlib_metadata.PackageNotFoundError:
    _ftfy_available = False


_bs4_available = importlib.util.find_spec("bs4") is not None
try:
    # importlib metadata under different name
    _bs4_version = importlib_metadata.version("beautifulsoup4")
    logger.debug(f"Successfully imported bs4 version {_bs4_version}")
except importlib_metadata.PackageNotFoundError:
    _bs4_available = False
_torchsde_version = importlib_metadata.version("torchsde") logger.debug(f"Successfully imported torchsde version {_torchsde_version}") except importlib_metadata.PackageNotFoundError: _torchsde_available = False _invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None try: _invisible_watermark_version = importlib_metadata.version("invisible-watermark") logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}") except importlib_metadata.PackageNotFoundError: _invisible_watermark_available = False _peft_available = importlib.util.find_spec("peft") is not None try: _peft_version = importlib_metadata.version("peft") logger.debug(f"Successfully imported peft version {_peft_version}") except importlib_metadata.PackageNotFoundError: _peft_available = False _torchvision_available = importlib.util.find_spec("torchvision") is not None try: _torchvision_version = importlib_metadata.version("torchvision") logger.debug(f"Successfully imported torchvision version {_torchvision_version}") except importlib_metadata.PackageNotFoundError: _torchvision_available = False _sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None try: _sentencepiece_version = importlib_metadata.version("sentencepiece") logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}") except importlib_metadata.PackageNotFoundError: _sentencepiece_available = False _matplotlib_available = importlib.util.find_spec("matplotlib") is not None try: _matplotlib_version = importlib_metadata.version("matplotlib") logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}") except importlib_metadata.PackageNotFoundError: _matplotlib_available = False _timm_available = importlib.util.find_spec("timm") is not None if _timm_available: try: _timm_version = importlib_metadata.version("timm") logger.info(f"Timm version {_timm_version} available.") except importlib_metadata.PackageNotFoundError: 
_timm_available = False def is_timm_available(): return _timm_available _bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None try: _bitsandbytes_version = importlib_metadata.version("bitsandbytes") logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}") except importlib_metadata.PackageNotFoundError: _bitsandbytes_available = False _is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ) def is_torch_available(): return _torch_available def is_torch_xla_available(): return _torch_xla_available def is_torch_npu_available(): return _torch_npu_available def is_flax_available(): return _flax_available def is_transformers_available(): return _transformers_available def is_inflect_available(): return _inflect_available def is_unidecode_available(): return _unidecode_available def is_onnx_available(): return _onnx_available def is_opencv_available(): return _opencv_available def is_scipy_available(): return _scipy_available def is_librosa_available(): return _librosa_available def is_xformers_available(): return _xformers_available def is_accelerate_available(): return _accelerate_available def is_k_diffusion_available(): return _k_diffusion_available def is_note_seq_available(): return _note_seq_available def is_wandb_available(): return _wandb_available def is_tensorboard_available(): return _tensorboard_available def is_compel_available(): return _compel_available def is_ftfy_available(): return _ftfy_available def is_bs4_available(): return _bs4_available def is_torchsde_available(): return _torchsde_available def is_invisible_watermark_available(): return _invisible_watermark_available def is_peft_available(): return _peft_available def is_torchvision_available(): return _torchvision_available def is_matplotlib_available(): return _matplotlib_available def is_safetensors_available(): return _safetensors_available def is_bitsandbytes_available(): return 
_bitsandbytes_available def is_google_colab(): return _is_google_colab def is_sentencepiece_available(): return _sentencepiece_available # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the installation page: https://github.com/google/flax and follow the ones that match your environment. """ # docstyle-ignore INFLECT_IMPORT_ERROR = """ {0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install inflect` """ # docstyle-ignore PYTORCH_IMPORT_ERROR = """ {0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. """ # docstyle-ignore ONNX_IMPORT_ERROR = """ {0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip install onnxruntime` """ # docstyle-ignore OPENCV_IMPORT_ERROR = """ {0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip install opencv-python` """ # docstyle-ignore SCIPY_IMPORT_ERROR = """ {0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install scipy` """ # docstyle-ignore LIBROSA_IMPORT_ERROR = """ {0} requires the librosa library but it was not found in your environment. Checkout the instructions on the installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment. """ # docstyle-ignore TRANSFORMERS_IMPORT_ERROR = """ {0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip install transformers` """ # docstyle-ignore UNIDECODE_IMPORT_ERROR = """ {0} requires the unidecode library but it was not found in your environment. 
You can install it with pip: `pip install Unidecode` """ # docstyle-ignore K_DIFFUSION_IMPORT_ERROR = """ {0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip install k-diffusion` """ # docstyle-ignore NOTE_SEQ_IMPORT_ERROR = """ {0} requires the note-seq library but it was not found in your environment. You can install it with pip: `pip install note-seq` """ # docstyle-ignore WANDB_IMPORT_ERROR = """ {0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip install wandb` """ # docstyle-ignore TENSORBOARD_IMPORT_ERROR = """ {0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip install tensorboard` """ # docstyle-ignore COMPEL_IMPORT_ERROR = """ {0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel` """ # docstyle-ignore BS4_IMPORT_ERROR = """ {0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip: `pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation. """ # docstyle-ignore FTFY_IMPORT_ERROR = """ {0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones that match your environment. Please note that you may need to restart your runtime after installation. """ # docstyle-ignore TORCHSDE_IMPORT_ERROR = """ {0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde` """ # docstyle-ignore INVISIBLE_WATERMARK_IMPORT_ERROR = """ {0} requires the invisible-watermark library but it was not found in your environment. 
You can install it with pip: `pip install invisible-watermark>=0.2.0` """ # docstyle-ignore PEFT_IMPORT_ERROR = """ {0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install peft` """ # docstyle-ignore SAFETENSORS_IMPORT_ERROR = """ {0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors` """ # docstyle-ignore SENTENCEPIECE_IMPORT_ERROR = """ {0} requires the sentencepiece library but it was not found in your environment. You can install it with pip: `pip install sentencepiece` """ # docstyle-ignore BITSANDBYTES_IMPORT_ERROR = """ {0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes` """ BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), ("safetensors", 
class DummyObject(type):
    """
    Metaclass for the dummy objects. Any class using it raises the ImportError generated by `requires_backends`
    each time a user tries to access a public attribute (or one of the whitelisted private hooks) of that class.
    """

    def __getattr__(cls, key):
        """
        Handle attribute lookups that normal class attribute resolution could not satisfy.

        Args:
            key (`str`): Name of the missing attribute.

        Raises:
            AttributeError: For private names (leading underscore) other than the whitelisted hooks.
            ImportError: Raised by `requires_backends` when a backend listed in `cls._backends` is missing.
        """
        if key.startswith("_") and key not in ["_load_connected_pipes", "_is_onnx"]:
            # `type` defines no `__getattr__`, so `super().__getattr__(...)` itself failed with a misleading
            # "'super' object has no attribute '__getattr__'". Delegating to `__getattribute__` raises the
            # regular AttributeError naming the class and the missing attribute.
            return super().__getattribute__(key)
        requires_backends(cls, cls._backends)
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
def is_torch_version(operation: str, version: str):
    """
    Compares the current PyTorch version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A string version of PyTorch

    Returns:
        The result of `compare_versions` applied to the parsed `_torch_version`, `operation` and `version`.
    """
    # NOTE(review): unlike `is_transformers_version`, no availability flag is checked before `_torch_version`
    # is parsed — confirm that `_torch_version` always holds a parseable version string.
    return compare_versions(parse(_torch_version), operation, version)
def is_peft_version(operation: str, version: str):
    """
    Compares the current PEFT version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string

    Returns:
        `False` when PEFT is not available, otherwise the result of `compare_versions`.
    """
    # Guard on the availability flag, like the sibling `is_*_version` helpers. `_peft_version` is only assigned
    # when the package metadata lookup succeeds, so testing it directly raised NameError when PEFT is missing.
    if not _peft_available:
        return False
    return compare_versions(parse(_peft_version), operation, version)
""" # Very heavily inspired by optuna.integration._IntegrationModule # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None): super().__init__(name) self._modules = set(import_structure.keys()) self._class_to_module = {} for key, values in import_structure.items(): for value in values: self._class_to_module[value] = key # Needed for autocompletion in an IDE self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) self.__file__ = module_file self.__spec__ = module_spec self.__path__ = [os.path.dirname(module_file)] self._objects = {} if extra_objects is None else extra_objects self._name = name self._import_structure = import_structure # Needed for autocompletion in an IDE def __dir__(self): result = super().__dir__() # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. for attr in self.__all__: if attr not in result: result.append(attr) return result def __getattr__(self, name: str) -> Any: if name in self._objects: return self._objects[name] if name in self._modules: value = self._get_module(name) elif name in self._class_to_module.keys(): module = self._get_module(self._class_to_module[name]) value = getattr(module, name) else: raise AttributeError(f"module {self.__name__} has no attribute {name}") setattr(self, name, value) return value def _get_module(self, module_name: str): try: return importlib.import_module("." 
+ module_name, self.__name__) except Exception as e: raise RuntimeError( f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its" f" traceback):\n{e}" ) from e def __reduce__(self): return (self.__class__, (self._name, self.__file__, self._import_structure)) ================================================ FILE: src/diffusers/utils/loading_utils.py ================================================ import os import tempfile from typing import Callable, List, Optional, Union import PIL.Image import PIL.ImageOps import requests from .import_utils import BACKENDS_MAPPING, is_opencv_available def load_image( image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None ) -> PIL.Image.Image: """ Loads `image` to a PIL Image. Args: image (`str` or `PIL.Image.Image`): The image to convert to the PIL Image format. convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*): A conversion method to apply to the image after loading it. When set to `None` the image will be converted "RGB". Returns: `PIL.Image.Image`: A PIL Image. """ if isinstance(image, str): if image.startswith("http://") or image.startswith("https://"): image = PIL.Image.open(requests.get(image, stream=True).raw) elif os.path.isfile(image): image = PIL.Image.open(image) else: raise ValueError( f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path." ) elif isinstance(image, PIL.Image.Image): image = image else: raise ValueError( "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image." 
def load_video(
    video: str,
    convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
) -> List[PIL.Image.Image]:
    """
    Loads `video` to a list of PIL Image.

    Args:
        video (`str`):
            A URL or Path to a video to convert to a list of PIL Image format.
        convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
            A conversion method to apply to the list of frames after loading it. When set to `None` the frames
            are returned as decoded: GIF frames are copied as-is, other videos are converted from BGR to RGB.

    Returns:
        `List[PIL.Image.Image]`: The video as a list of PIL images.

    Raises:
        ValueError: If `video` is neither an `http(s)` URL nor an existing file.
        ImportError: If a non-GIF video is given and OpenCV is not available.
        requests.HTTPError: If downloading a URL returns an HTTP error status.
    """
    # Local import keeps this block self-contained (stdlib only).
    from urllib.parse import urlparse

    is_url = video.startswith(("http://", "https://"))
    is_file = os.path.isfile(video)

    if not (is_url or is_file):
        raise ValueError(
            f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path."
        )

    temp_path = None
    try:
        if is_url:
            # Derive the extension from the URL path only, so a query string or fragment neither ends up in
            # the temporary file name nor hides a `.gif` suffix.
            suffix = os.path.splitext(urlparse(video).path)[1]
            response = requests.get(video, stream=True)
            try:
                # Fail loudly instead of writing an HTTP error page to disk and decoding zero frames.
                response.raise_for_status()
                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
                    temp_path = temp_file.name
                    temp_file.write(response.raw.read())
            finally:
                response.close()
            video = temp_path

        pil_images = []
        if video.endswith(".gif"):
            # The context manager releases the file handle once every frame has been copied.
            with PIL.Image.open(video) as gif:
                try:
                    while True:
                        pil_images.append(gif.copy())
                        gif.seek(gif.tell() + 1)
                except EOFError:
                    # Seeking past the last frame signals the end of the animation.
                    pass
        else:
            if is_opencv_available():
                import cv2
            else:
                raise ImportError(BACKENDS_MAPPING["opencv"][1].format("load_video"))

            video_capture = cv2.VideoCapture(video)
            try:
                success, frame = video_capture.read()
                while success:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    pil_images.append(PIL.Image.fromarray(frame))
                    success, frame = video_capture.read()
            finally:
                video_capture.release()
    finally:
        # Remove the downloaded copy even when decoding (or the OpenCV check) fails.
        if temp_path is not None:
            os.remove(temp_path)

    if convert_method is not None:
        pil_images = convert_method(pil_images)

    return pil_images
def _get_default_logging_level() -> int:
    """
    Pick the default level for the library root logger.

    A non-empty `DIFFUSERS_VERBOSITY` env var naming one of the `log_levels` keys selects that level. An
    unknown value is reported through the Python root logger, and `_default_log_level` is used instead. An
    unset or empty variable also yields `_default_log_level`.
    """
    requested = os.environ.get("DIFFUSERS_VERBOSITY")
    if not requested:
        return _default_log_level

    if requested not in log_levels:
        logging.getLogger().warning(
            f"Unknown option DIFFUSERS_VERBOSITY={requested}, "
            f"has to be one of: { ', '.join(log_levels.keys()) }"
        )
        return _default_log_level

    return log_levels[requested]
Returns: `int`: Logging level integers which can be one of: - `50`: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` - `40`: `diffusers.logging.ERROR` - `30`: `diffusers.logging.WARNING` or `diffusers.logging.WARN` - `20`: `diffusers.logging.INFO` - `10`: `diffusers.logging.DEBUG` """ _configure_library_root_logger() return _get_library_root_logger().getEffectiveLevel() def set_verbosity(verbosity: int) -> None: """ Set the verbosity level for the 🤗 Diffusers' root logger. Args: verbosity (`int`): Logging level which can be one of: - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` - `diffusers.logging.ERROR` - `diffusers.logging.WARNING` or `diffusers.logging.WARN` - `diffusers.logging.INFO` - `diffusers.logging.DEBUG` """ _configure_library_root_logger() _get_library_root_logger().setLevel(verbosity) def set_verbosity_info() -> None: """Set the verbosity to the `INFO` level.""" return set_verbosity(INFO) def set_verbosity_warning() -> None: """Set the verbosity to the `WARNING` level.""" return set_verbosity(WARNING) def set_verbosity_debug() -> None: """Set the verbosity to the `DEBUG` level.""" return set_verbosity(DEBUG) def set_verbosity_error() -> None: """Set the verbosity to the `ERROR` level.""" return set_verbosity(ERROR) def disable_default_handler() -> None: """Disable the default handler of the 🤗 Diffusers' root logger.""" _configure_library_root_logger() assert _default_handler is not None _get_library_root_logger().removeHandler(_default_handler) def enable_default_handler() -> None: """Enable the default handler of the 🤗 Diffusers' root logger.""" _configure_library_root_logger() assert _default_handler is not None _get_library_root_logger().addHandler(_default_handler) def add_handler(handler: logging.Handler) -> None: """adds a handler to the HuggingFace Diffusers' root logger.""" _configure_library_root_logger() assert handler is not None _get_library_root_logger().addHandler(handler) def remove_handler(handler: 
def enable_explicit_format() -> None:
    """
    Switch every handler attached to the 🤗 Diffusers' root logger to the explicit record layout:
    ```
    [LEVELNAME|FILENAME:LINE NUMBER] TIME >> MESSAGE
    ```
    Only handlers bound to the library root logger at call time are affected.
    """
    layout = "[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s"
    for bound_handler in _get_library_root_logger().handlers:
        # Each handler gets its own Formatter instance.
        bound_handler.setFormatter(logging.Formatter(layout))
""" handlers = _get_library_root_logger().handlers for handler in handlers: handler.setFormatter(None) def warning_advice(self, *args, **kwargs) -> None: """ This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this warning will not be printed """ no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False) if no_advisory_warnings: return self.warning(*args, **kwargs) logging.Logger.warning_advice = warning_advice class EmptyTqdm: """Dummy tqdm which doesn't do anything.""" def __init__(self, *args, **kwargs): # pylint: disable=unused-argument self._iterator = args[0] if args else None def __iter__(self): return iter(self._iterator) def __getattr__(self, _): """Return empty function.""" def empty_fn(*args, **kwargs): # pylint: disable=unused-argument return return empty_fn def __enter__(self): return self def __exit__(self, type_, value, traceback): return class _tqdm_cls: def __call__(self, *args, **kwargs): if _tqdm_active: return tqdm_lib.tqdm(*args, **kwargs) else: return EmptyTqdm(*args, **kwargs) def set_lock(self, *args, **kwargs): self._lock = None if _tqdm_active: return tqdm_lib.tqdm.set_lock(*args, **kwargs) def get_lock(self): if _tqdm_active: return tqdm_lib.tqdm.get_lock() tqdm = _tqdm_cls() def is_progress_bar_enabled() -> bool: """Return a boolean indicating whether tqdm progress bars are enabled.""" global _tqdm_active return bool(_tqdm_active) def enable_progress_bar() -> None: """Enable tqdm progress bar.""" global _tqdm_active _tqdm_active = True def disable_progress_bar() -> None: """Disable tqdm progress bar.""" global _tqdm_active _tqdm_active = False ================================================ FILE: src/diffusers/utils/model_card_template.md ================================================ --- {{ card_data }} --- {{ model_description }} ## Intended uses & limitations #### How to use ```python # TODO: add an example code snippet for running this diffusion pipeline ``` #### 
def is_tensor(x) -> bool:
    """
    Report whether `x` is a `torch.Tensor` or an `np.ndarray`.

    The torch check only runs when `is_torch_available()` is true; torch is imported lazily in that case.
    """
    if is_torch_available():
        import torch

        matches_torch = isinstance(x, torch.Tensor)
    else:
        matches_torch = False

    return matches_torch or isinstance(x, np.ndarray)
""" if is_torch_available(): import torch.utils._pytree if is_torch_version("<", "2.2"): torch.utils._pytree._register_pytree_node( cls, torch.utils._pytree._dict_flatten, lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), ) else: torch.utils._pytree.register_pytree_node( cls, torch.utils._pytree._dict_flatten, lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), ) def __post_init__(self) -> None: class_fields = fields(self) # Safety and consistency checks if not len(class_fields): raise ValueError(f"{self.__class__.__name__} has no fields.") first_field = getattr(self, class_fields[0].name) other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) if other_fields_are_none and isinstance(first_field, dict): for key, value in first_field.items(): self[key] = value else: for field in class_fields: v = getattr(self, field.name) if v is not None: self[field.name] = v def __delitem__(self, *args, **kwargs): raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") def setdefault(self, *args, **kwargs): raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") def pop(self, *args, **kwargs): raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") def update(self, *args, **kwargs): raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") def __getitem__(self, k: Any) -> Any: if isinstance(k, str): inner_dict = dict(self.items()) return inner_dict[k] else: return self.to_tuple()[k] def __setattr__(self, name: Any, value: Any) -> None: if name in self.keys() and value is not None: # Don't call self.__setitem__ to avoid recursion errors super().__setitem__(name, value) super().__setattr__(name, value) def __setitem__(self, key, value): # Will raise a KeyException if needed super().__setitem__(key, value) # Don't call self.__setattr__ to avoid recursion 
def recurse_remove_peft_layers(model):
    r"""
    Recursively replace all PEFT LoRA layers in `model` with plain layers and return `model`.

    Two code paths exist:
      - If the first `BaseTunerLayer` found exposes `base_layer`, every submodule that has a `base_layer` is
        swapped for `get_base_layer()`.
      - Otherwise (backwards compatibility with PEFT <= 0.6.2) LoRA `Linear`/`Conv2d` children are rebuilt as
        fresh `torch.nn` layers that reuse the original weight (and bias, when present).

    Args:
        model (`torch.nn.Module`): The model to strip; it is modified in place.

    Returns:
        `torch.nn.Module`: The same `model` instance.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    has_base_layer_pattern = False
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            has_base_layer_pattern = hasattr(module, "base_layer")
            break

    if has_base_layer_pattern:
        from peft.utils import _get_submodules

        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
        for key in key_list:
            try:
                parent, target, target_name = _get_submodules(model, key)
            except AttributeError:
                # The key no longer resolves, e.g. because an ancestor was already replaced.
                continue
            if hasattr(target, "base_layer"):
                setattr(parent, target_name, target.get_base_layer())
    else:
        # This is for backwards compatibility with PEFT <= 0.6.2.
        # TODO can be removed once that PEFT version is no longer supported.
        from peft.tuners.lora import LoraLayer

        for name, module in model.named_children():
            if len(list(module.children())) > 0:
                ## compound module, go inside it
                recurse_remove_peft_layers(module)

            module_replaced = False

            if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
                new_module = torch.nn.Linear(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                ).to(module.weight.device)
                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias

                module_replaced = True
            elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
                new_module = torch.nn.Conv2d(
                    module.in_channels,
                    module.out_channels,
                    module.kernel_size,
                    module.stride,
                    module.padding,
                    module.dilation,
                    module.groups,
                    # Mirror the Linear branch: without this, Conv2d defaults to `bias=True` and a bias-free
                    # layer would be replaced by one carrying a randomly initialised bias.
                    bias=module.bias is not None,
                ).to(module.weight.device)

                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias

                module_replaced = True

            if module_replaced:
                setattr(model, name, new_module)
                del module

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    return model
""" from peft.tuners.tuners_utils import BaseTunerLayer if weight == 1.0: return for module in model.modules(): if isinstance(module, BaseTunerLayer): module.scale_layer(weight) def unscale_lora_layers(model, weight: Optional[float] = None): """ Removes the previously passed weight given to the LoRA layers of the model. Args: model (`torch.nn.Module`): The model to scale. weight (`float`, *optional*): The weight to be given to the LoRA layers. If no scale is passed the scale of the lora layer will be re-initialized to the correct value. If 0.0 is passed, we will re-initialize the scale with the correct value. """ from peft.tuners.tuners_utils import BaseTunerLayer if weight == 1.0: return for module in model.modules(): if isinstance(module, BaseTunerLayer): if weight is not None and weight != 0: module.unscale_layer(weight) elif weight is not None and weight == 0: for adapter_name in module.active_adapters: # if weight == 0 unscale should re-set the scale to the original value. module.set_scale(adapter_name, 1.0) def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True): rank_pattern = {} alpha_pattern = {} r = lora_alpha = list(rank_dict.values())[0] if len(set(rank_dict.values())) > 1: # get the rank occuring the most number of times r = collections.Counter(rank_dict.values()).most_common()[0][0] # for modules with rank different from the most occuring rank, add it to the `rank_pattern` rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items())) rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()} if network_alpha_dict is not None and len(network_alpha_dict) > 0: if len(set(network_alpha_dict.values())) > 1: # get the alpha occuring the most number of times lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern` alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, 
network_alpha_dict.items())) if is_unet: alpha_pattern = { ".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v for k, v in alpha_pattern.items() } else: alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} else: lora_alpha = set(network_alpha_dict.values()).pop() # layer names without the Diffusers specific target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict) lora_config_kwargs = { "r": r, "lora_alpha": lora_alpha, "rank_pattern": rank_pattern, "alpha_pattern": alpha_pattern, "target_modules": target_modules, "use_dora": use_dora, } return lora_config_kwargs def get_adapter_name(model): from peft.tuners.tuners_utils import BaseTunerLayer for module in model.modules(): if isinstance(module, BaseTunerLayer): return f"default_{len(module.r)}" return "default_0" def set_adapter_layers(model, enabled=True): from peft.tuners.tuners_utils import BaseTunerLayer for module in model.modules(): if isinstance(module, BaseTunerLayer): # The recent version of PEFT needs to call `enable_adapters` instead if hasattr(module, "enable_adapters"): module.enable_adapters(enabled=enabled) else: module.disable_adapters = not enabled def delete_adapter_layers(model, adapter_name): from peft.tuners.tuners_utils import BaseTunerLayer for module in model.modules(): if isinstance(module, BaseTunerLayer): if hasattr(module, "delete_adapter"): module.delete_adapter(adapter_name) else: raise ValueError( "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1" ) # For transformers integration - we need to pop the adapter from the config if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"): model.peft_config.pop(adapter_name, None) # In case all adapters are deleted, we need to delete the config # and make sure to set the flag to False if 
def check_peft_version(min_version: str) -> None:
    r"""
    Checks if the version of PEFT is compatible.

    Args:
        min_version (`str`):
            The version of PEFT to check against. The installed version must be strictly greater than this value.

    Raises:
        `ValueError`: If PEFT is not installed, or if the installed version is not greater than `min_version`.
    """
    # `import importlib` on its own does not guarantee that the `metadata` submodule is loaded; import it
    # explicitly so `importlib.metadata.version` cannot fail with an AttributeError.
    import importlib.metadata

    if not is_peft_available():
        raise ValueError("PEFT is not installed. Please install it with `pip install peft`")

    is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) > version.parse(min_version)

    if not is_peft_version_compatible:
        raise ValueError(
            f"The version of PEFT you are using is not compatible, please use a version that is greater"
            f" than {min_version}"
        )
""" assert len(images) == rows * cols if resize is not None: images = [img.resize((resize, resize)) for img in images] w, h = images[0].size grid = Image.new("RGB", size=(cols * w, rows * h)) for i, img in enumerate(images): grid.paste(img, box=(i % cols * w, i // cols * h)) return grid ================================================ FILE: src/diffusers/utils/state_dict_utils.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ State dict utilities: utility methods for converting state dicts easily """ import enum from .logging import get_logger logger = get_logger(__name__) class StateDictType(enum.Enum): """ The mode to use when converting state dicts. """ DIFFUSERS_OLD = "diffusers_old" KOHYA_SS = "kohya_ss" PEFT = "peft" DIFFUSERS = "diffusers" # We need to define a proper mapping for Unet since it uses different output keys than text encoder # e.g. 
to_q_lora -> q_proj / to_q UNET_TO_DIFFUSERS = { ".to_out_lora.up": ".to_out.0.lora_B", ".to_out_lora.down": ".to_out.0.lora_A", ".to_q_lora.down": ".to_q.lora_A", ".to_q_lora.up": ".to_q.lora_B", ".to_k_lora.down": ".to_k.lora_A", ".to_k_lora.up": ".to_k.lora_B", ".to_v_lora.down": ".to_v.lora_A", ".to_v_lora.up": ".to_v.lora_B", ".lora.up": ".lora_B", ".lora.down": ".lora_A", ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector", } DIFFUSERS_TO_PEFT = { ".q_proj.lora_linear_layer.up": ".q_proj.lora_B", ".q_proj.lora_linear_layer.down": ".q_proj.lora_A", ".k_proj.lora_linear_layer.up": ".k_proj.lora_B", ".k_proj.lora_linear_layer.down": ".k_proj.lora_A", ".v_proj.lora_linear_layer.up": ".v_proj.lora_B", ".v_proj.lora_linear_layer.down": ".v_proj.lora_A", ".out_proj.lora_linear_layer.up": ".out_proj.lora_B", ".out_proj.lora_linear_layer.down": ".out_proj.lora_A", ".lora_linear_layer.up": ".lora_B", ".lora_linear_layer.down": ".lora_A", "text_projection.lora.down.weight": "text_projection.lora_A.weight", "text_projection.lora.up.weight": "text_projection.lora_B.weight", } DIFFUSERS_OLD_TO_PEFT = { ".to_q_lora.up": ".q_proj.lora_B", ".to_q_lora.down": ".q_proj.lora_A", ".to_k_lora.up": ".k_proj.lora_B", ".to_k_lora.down": ".k_proj.lora_A", ".to_v_lora.up": ".v_proj.lora_B", ".to_v_lora.down": ".v_proj.lora_A", ".to_out_lora.up": ".out_proj.lora_B", ".to_out_lora.down": ".out_proj.lora_A", ".lora_linear_layer.up": ".lora_B", ".lora_linear_layer.down": ".lora_A", } PEFT_TO_DIFFUSERS = { ".q_proj.lora_B": ".q_proj.lora_linear_layer.up", ".q_proj.lora_A": ".q_proj.lora_linear_layer.down", ".k_proj.lora_B": ".k_proj.lora_linear_layer.up", ".k_proj.lora_A": ".k_proj.lora_linear_layer.down", ".v_proj.lora_B": ".v_proj.lora_linear_layer.up", ".v_proj.lora_A": ".v_proj.lora_linear_layer.down", ".out_proj.lora_B": ".out_proj.lora_linear_layer.up", ".out_proj.lora_A": ".out_proj.lora_linear_layer.down", "to_k.lora_A": "to_k.lora.down", "to_k.lora_B": 
"to_k.lora.up", "to_q.lora_A": "to_q.lora.down", "to_q.lora_B": "to_q.lora.up", "to_v.lora_A": "to_v.lora.down", "to_v.lora_B": "to_v.lora.up", "to_out.0.lora_A": "to_out.0.lora.down", "to_out.0.lora_B": "to_out.0.lora.up", } DIFFUSERS_OLD_TO_DIFFUSERS = { ".to_q_lora.up": ".q_proj.lora_linear_layer.up", ".to_q_lora.down": ".q_proj.lora_linear_layer.down", ".to_k_lora.up": ".k_proj.lora_linear_layer.up", ".to_k_lora.down": ".k_proj.lora_linear_layer.down", ".to_v_lora.up": ".v_proj.lora_linear_layer.up", ".to_v_lora.down": ".v_proj.lora_linear_layer.down", ".to_out_lora.up": ".out_proj.lora_linear_layer.up", ".to_out_lora.down": ".out_proj.lora_linear_layer.down", ".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector", ".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector", ".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector", ".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector", } PEFT_TO_KOHYA_SS = { "lora_A": "lora_down", "lora_B": "lora_up", # This is not a comprehensive dict as kohya format requires replacing `.` with `_` in keys, # adding prefixes and adding alpha values # Check `convert_state_dict_to_kohya` for more } PEFT_STATE_DICT_MAPPINGS = { StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT, StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT, } DIFFUSERS_STATE_DICT_MAPPINGS = { StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS, StateDictType.PEFT: PEFT_TO_DIFFUSERS, } KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS} KEYS_TO_ALWAYS_REPLACE = { ".processor.": ".", } def convert_state_dict(state_dict, mapping): r""" Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values. Args: state_dict (`dict[str, torch.Tensor]`): The state dict to convert. 
mapping (`dict[str, str]`): The mapping to use for conversion, the mapping should be a dictionary with the following structure: - key: the pattern to replace - value: the pattern to replace with Returns: converted_state_dict (`dict`) The converted state dict. """ converted_state_dict = {} for k, v in state_dict.items(): # First, filter out the keys that we always want to replace for pattern in KEYS_TO_ALWAYS_REPLACE.keys(): if pattern in k: new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern] k = k.replace(pattern, new_pattern) for pattern in mapping.keys(): if pattern in k: new_pattern = mapping[pattern] k = k.replace(pattern, new_pattern) break converted_state_dict[k] = v return converted_state_dict def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs): r""" Converts a state dict to the PEFT format The state dict can be from previous diffusers format (`OLD_DIFFUSERS`), or new diffusers format (`DIFFUSERS`). The method only supports the conversion from diffusers old/new to PEFT for now. Args: state_dict (`dict[str, torch.Tensor]`): The state dict to convert. original_type (`StateDictType`, *optional*): The original type of the state dict, if not provided, the method will try to infer it automatically. """ if original_type is None: # Old diffusers to PEFT if any("to_out_lora" in k for k in state_dict.keys()): original_type = StateDictType.DIFFUSERS_OLD elif any("lora_linear_layer" in k for k in state_dict.keys()): original_type = StateDictType.DIFFUSERS else: raise ValueError("Could not automatically infer state dict type") if original_type not in PEFT_STATE_DICT_MAPPINGS.keys(): raise ValueError(f"Original type {original_type} is not supported") mapping = PEFT_STATE_DICT_MAPPINGS[original_type] return convert_state_dict(state_dict, mapping) def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): r""" Converts a state dict to new diffusers format. 
The state dict can be from previous diffusers format (`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will return the state dict as is. The method only supports the conversion from diffusers old, PEFT to diffusers new for now. Args: state_dict (`dict[str, torch.Tensor]`): The state dict to convert. original_type (`StateDictType`, *optional*): The original type of the state dict, if not provided, the method will try to infer it automatically. kwargs (`dict`, *args*): Additional arguments to pass to the method. - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in `get_peft_model_state_dict` method: https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 but we add it here in case we don't want to rely on that method. """ peft_adapter_name = kwargs.pop("adapter_name", None) if peft_adapter_name is not None: peft_adapter_name = "." + peft_adapter_name else: peft_adapter_name = "" if original_type is None: # Old diffusers to PEFT if any("to_out_lora" in k for k in state_dict.keys()): original_type = StateDictType.DIFFUSERS_OLD elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()): original_type = StateDictType.PEFT elif any("lora_linear_layer" in k for k in state_dict.keys()): # nothing to do return state_dict else: raise ValueError("Could not automatically infer state dict type") if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys(): raise ValueError(f"Original type {original_type} is not supported") mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type] return convert_state_dict(state_dict, mapping) def convert_unet_state_dict_to_peft(state_dict): r""" Converts a state dict from UNet format to diffusers format - i.e. 
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
    r"""
    Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc.
    The method only supports the conversion from PEFT to Kohya for now.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict, if not provided, the method will try to infer it automatically.
        kwargs (`dict`, *args*):
            Additional arguments to pass to the method.

            - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
                with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
                `get_peft_model_state_dict` method:
                https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
                but we add it here in case we don't want to rely on that method.

    Returns:
        `dict`: The converted state dict. For every key containing `lora_down`, an extra `<module>.alpha` entry is
        added holding `torch.tensor(len(weight))`.

    Raises:
        `ImportError`: If torch is not installed.
        `ValueError`: If `original_type` is not (or cannot be inferred to be) a supported type.
    """
    try:
        import torch
    except ImportError:
        logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.")
        raise

    peft_adapter_name = kwargs.pop("adapter_name", None)
    if peft_adapter_name is not None:
        peft_adapter_name = "." + peft_adapter_name
    else:
        peft_adapter_name = ""

    if original_type is None:
        if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
            original_type = StateDictType.PEFT

    # When inference fails `original_type` is still None and is rejected here.
    if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys():
        raise ValueError(f"Original type {original_type} is not supported")

    # Use the convert_state_dict function with the appropriate mapping
    kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT])
    kohya_ss_state_dict = {}

    # Additional logic for replacing header, alpha parameters `.` with `_` in all keys
    for kohya_key, weight in kohya_ss_partial_state_dict.items():
        # NOTE(review): because this is a single if/elif chain, the `lora_magnitude_vector` rename is only
        # reached for keys that contain none of the prefixes above it - confirm that this is intended.
        if "text_encoder_2." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
        elif "text_encoder." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
        elif "unet" in kohya_key:
            kohya_key = kohya_key.replace("unet", "lora_unet")
        elif "lora_magnitude_vector" in kohya_key:
            kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale")

        # Kohya doesn't take names. The adapter segment has to be removed *before* the dots are flattened:
        # the flattening below keeps the last two dots (`<module>.lora_down.weight`), and an extra
        # `.<adapter_name>` segment would otherwise shift that window and merge `lora_down` into the module name.
        if peft_adapter_name:
            kohya_key = kohya_key.replace(peft_adapter_name, "")

        # Replace every dot except the last two with underscores.
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        kohya_ss_state_dict[kohya_key] = weight
        if "lora_down" in kohya_key:
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            # `len(weight)` is the size of the first dimension of the down-projection weight.
            kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))

    return kohya_ss_state_dict
norm from packaging import version from .import_utils import ( BACKENDS_MAPPING, is_compel_available, is_flax_available, is_note_seq_available, is_onnx_available, is_opencv_available, is_peft_available, is_timm_available, is_torch_available, is_torch_version, is_torchsde_available, is_transformers_available, ) from .logging import get_logger global_rng = random.Random() logger = get_logger(__name__) _required_peft_version = is_peft_available() and version.parse( version.parse(importlib.metadata.version("peft")).base_version ) > version.parse("0.5") _required_transformers_version = is_transformers_available() and version.parse( version.parse(importlib.metadata.version("transformers")).base_version ) > version.parse("4.33") USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version if is_torch_available(): import torch # Set a backend environment variable for any extra module import required for a custom accelerator if "DIFFUSERS_TEST_BACKEND" in os.environ: backend = os.environ["DIFFUSERS_TEST_BACKEND"] try: _ = importlib.import_module(backend) except ModuleNotFoundError as e: raise ModuleNotFoundError( f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module \ to enable a specified backend.):\n{e}" ) from e if "DIFFUSERS_TEST_DEVICE" in os.environ: torch_device = os.environ["DIFFUSERS_TEST_DEVICE"] try: # try creating device to see if provided device is valid _ = torch.device(torch_device) except RuntimeError as e: raise RuntimeError( f"Unknown testing device specified by environment variable `DIFFUSERS_TEST_DEVICE`: {torch_device}" ) from e logger.info(f"torch_device overrode to {torch_device}") else: torch_device = "cuda" if torch.cuda.is_available() else "cpu" is_torch_higher_equal_than_1_12 = version.parse( version.parse(torch.__version__).base_version ) >= version.parse("1.12") if is_torch_higher_equal_than_1_12: # Some builds of torch 1.12 don't have the mps backend registered. 
def torch_all_close(a, b, *args, **kwargs):
    """
    Check that two tensors are element-wise close, failing loudly with the observed difference otherwise.

    Args:
        a, b:
            Tensors to compare; forwarded to `torch.allclose`.
        *args, **kwargs:
            Extra arguments forwarded unchanged to `torch.allclose`.

    Returns:
        `bool`: Always `True` when the tensors are close.

    Raises:
        `ValueError`: If PyTorch is not installed.
        `AssertionError`: If the tensors are not close; the message reports the maximum absolute difference and
            the full difference tensor.
    """
    if not is_torch_available():
        raise ValueError("PyTorch needs to be installed to use this function.")
    if not torch.allclose(a, b, *args, **kwargs):
        # Raise explicitly instead of `assert False`: assert statements are removed under `python -O`, which
        # would make this helper silently return True for mismatching tensors.
        raise AssertionError(f"Max diff is absolute {(a - b).abs().max()}. Diff tensor is {(a - b).abs()}.")
    return True
""" # this function caller's __file__ caller__file__ = inspect.stack()[1][1] tests_dir = os.path.abspath(os.path.dirname(caller__file__)) while not tests_dir.endswith("tests"): tests_dir = os.path.dirname(tests_dir) if append_path: return Path(tests_dir, append_path).as_posix() else: return tests_dir # Taken from the following PR: # https://github.com/huggingface/accelerate/pull/1964 def str_to_bool(value) -> int: """ Converts a string representation of truth to `True` (1) or `False` (0). True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`; """ value = value.lower() if value in ("y", "yes", "t", "true", "on", "1"): return 1 elif value in ("n", "no", "f", "false", "off", "0"): return 0 else: raise ValueError(f"invalid truth value {value}") def parse_flag_from_env(key, default=False): try: value = os.environ[key] except KeyError: # KEY isn't set, default to `default`. _value = default else: # KEY is set, convert it to True or False. try: _value = str_to_bool(value) except ValueError: # More values are supported, but let's keep the message simple. raise ValueError(f"If set, {key} must be yes or no.") return _value _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) _run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False) _run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False) def floats_tensor(shape, scale=1.0, rng=None, name=None): """Creates a random float32 tensor""" if rng is None: rng = global_rng total_dims = 1 for dim in shape: total_dims *= dim values = [] for _ in range(total_dims): values.append(rng.random() * scale) return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() def slow(test_case): """ Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. 
""" return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) def nightly(test_case): """ Decorator marking a test that runs nightly in the diffusers CI. Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them. """ return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case) def is_torch_compile(test_case): """ Decorator marking a test that runs compile tests in the diffusers CI. Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them. """ return unittest.skipUnless(_run_compile_tests, "test is torch compile")(test_case) def require_torch(test_case): """ Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed. """ return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) def require_torch_2(test_case): """ Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed. """ return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")( test_case ) def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( test_case ) # These decorators are for accelerator-specific behaviours that are not GPU-specific def require_torch_accelerator(test_case): """Decorator marking a test that requires an accelerator backend and PyTorch.""" return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")( test_case ) def require_torch_multi_gpu(test_case): """ Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without multiple GPUs. 
def require_torch_accelerator_with_fp16(test_case):
    """
    Decorator marking a test that requires an accelerator with support for the FP16 data type.

    The test is skipped when PyTorch is not installed or when `_is_torch_fp16_available(torch_device)` is falsy.
    """
    # Guard with `is_torch_available()` first, as `require_torch_accelerator_with_training` does, so that
    # `torch_device` is only read when PyTorch is installed.
    return unittest.skipUnless(
        is_torch_available() and _is_torch_fp16_available(torch_device),
        "test requires accelerator with fp16 support",
    )(test_case)
def require_peft_backend(test_case):
    """
    Decorator marking a test that requires the PEFT backend.

    The test is skipped unless the module-level `USE_PEFT_BACKEND` flag is truthy, which in turn depends on the
    installed versions of PEFT and transformers.
    """
    skip_without_peft_backend = unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")
    return skip_without_peft_backend(test_case)
""" def decorator(test_case): correct_peft_version = is_peft_available() and version.parse( version.parse(importlib.metadata.version("peft")).base_version ) > version.parse(peft_version) return unittest.skipUnless( correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}" )(test_case) return decorator def require_accelerate_version_greater(accelerate_version): def decorator(test_case): correct_accelerate_version = is_peft_available() and version.parse( version.parse(importlib.metadata.version("accelerate")).base_version ) > version.parse(accelerate_version) return unittest.skipUnless( correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}." )(test_case) return decorator def deprecate_after_peft_backend(test_case): """ Decorator marking a test that will be skipped after PEFT backend """ return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case) def get_python_version(): sys_info = sys.version_info major, minor = sys_info.major, sys_info.minor return major, minor def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray: if isinstance(arry, str): if local_path is not None: # local_path can be passed to correct images of tests return Path(local_path, arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]).as_posix() elif arry.startswith("http://") or arry.startswith("https://"): response = requests.get(arry) response.raise_for_status() arry = np.load(BytesIO(response.content)) elif os.path.isfile(arry): arry = np.load(arry) else: raise ValueError( f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path" ) elif isinstance(arry, np.ndarray): pass else: raise ValueError( "Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a" " ndarray." 
@contextmanager
def buffered_writer(raw_f):
    """
    Context manager that wraps `raw_f` in an `io.BufferedWriter` and flushes it on exit.

    Args:
        raw_f: The underlying writable binary stream. It is not closed by this context manager.

    Yields:
        The `io.BufferedWriter` wrapping `raw_f`.
    """
    f = io.BufferedWriter(raw_f)
    try:
        yield f
    finally:
        # Flush even if the `with` body raised, so bytes already written to the buffer reach `raw_f` instead of
        # depending on garbage collection of the writer.
        f.flush()
def load_hf_numpy(path) -> np.ndarray:
    """
    Load a numpy array through `load_numpy`, resolving relative paths against the `fusing/diffusers-testing` dataset.

    Args:
        path (`str`):
            Either a full `http://` / `https://` URL, which is passed through untouched, or a path that is
            percent-quoted and joined onto the dataset base URL.

    Returns:
        Whatever `load_numpy` returns for the resolved location.
    """
    # URLs always use forward slashes. `os.path.join` inserts backslashes on Windows and would produce an invalid
    # URL there; `posixpath.join` behaves exactly like `os.path.join` does on POSIX systems.
    import posixpath

    base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main"

    if not path.startswith(("http://", "https://")):
        path = posixpath.join(base_url, urllib.parse.quote(path))

    return load_numpy(path)
# Registry of pytest options that were already added. Both tests/conftest.py and examples/conftest.py may call
# `pytest_addoption_shared`, and the option must only be registered once.
pytest_opt_registered = {}


def pytest_addoption_shared(parser):
    """
    Register the `--make-reports` option on `parser`, at most once per process.

    This function is to be called from `conftest.py` via a `pytest_addoption` wrapper that has to be defined there.
    Because the registration is remembered in `pytest_opt_registered`, several `conftest.py` files can be loaded
    together without failing on a duplicated `pytest` option.
    """
    option = "--make-reports"
    if option in pytest_opt_registered:
        # A previous call already added the option; nothing left to do.
        return

    parser.addoption(
        option,
        action="store",
        default=False,
        help="generate report files. The value of this option is used as a prefix to report names",
    )
    pytest_opt_registered[option] = 1
""" from _pytest.config import create_terminal_writer if not len(id): id = "tests" config = tr.config orig_writer = config.get_terminal_writer() orig_tbstyle = config.option.tbstyle orig_reportchars = tr.reportchars dir = "reports" Path(dir).mkdir(parents=True, exist_ok=True) report_files = { k: f"{dir}/{id}_{k}.txt" for k in [ "durations", "errors", "failures_long", "failures_short", "failures_line", "passes", "stats", "summary_short", "warnings", ] } # custom durations report # note: there is no need to call pytest --durations=XX to get this separate report # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66 dlist = [] for replist in tr.stats.values(): for rep in replist: if hasattr(rep, "duration"): dlist.append(rep) if dlist: dlist.sort(key=lambda x: x.duration, reverse=True) with open(report_files["durations"], "w") as f: durations_min = 0.05 # sec f.write("slowest durations\n") for i, rep in enumerate(dlist): if rep.duration < durations_min: f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted") break f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n") def summary_failures_short(tr): # expecting that the reports were --tb=long (default) so we chop them off here to the last frame reports = tr.getreports("failed") if not reports: return tr.write_sep("=", "FAILURES SHORT STACK") for rep in reports: msg = tr._getfailureheadline(rep) tr.write_sep("_", msg, red=True, bold=True) # chop off the optional leading extra frames, leaving only the last one longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S) tr._tw.line(longrepr) # note: not printing out any rep.sections to keep the report short # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814 # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g. 
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
    """
    Decorator factory for flaky tests: a failing test is re-run until it passes or the attempts are used up.

    Args:
        max_attempts (`int`, *optional*, defaults to 5):
            Total number of times the test may be executed. Failures of every attempt but the last are caught and
            reported on stderr; the last attempt runs unguarded so its exception propagates.
        wait_before_retry (`float`, *optional*):
            If provided, number of seconds to sleep after a failed attempt before trying again.
        description (`str`, *optional*):
            A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors,
            etc.). It is not used by the implementation.
    """

    def decorator(test_func_ref):
        @functools.wraps(test_func_ref)
        def wrapper(*args, **kwargs):
            # Guarded attempts: numbers 1 .. max_attempts - 1.
            for attempt in range(1, max_attempts):
                try:
                    return test_func_ref(*args, **kwargs)
                except Exception as err:
                    print(f"Test failed with {err} at try {attempt}/{max_attempts}.", file=sys.stderr)
                    if wait_before_retry is not None:
                        time.sleep(wait_before_retry)

            # Final attempt: let any exception reach the test runner.
            return test_func_ref(*args, **kwargs)

        return wrapper

    return decorator
class CaptureLogger:
    """
    Context manager to capture `logging` streams.

    Args:
        logger: `logging` logger object whose records should be captured.

    Returns:
        The captured output is available via `self.out` once the `with` block has exited.

    Example:

    ```python
    >>> from diffusers import logging
    >>> from diffusers.testing_utils import CaptureLogger

    >>> msg = "Testing 1, 2, 3"
    >>> logging.set_verbosity_info()
    >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
    >>> with CaptureLogger(logger) as cl:
    ...     logger.info(msg)
    >>> assert cl.out == msg + "\n"
    ```
    """

    def __init__(self, logger):
        # `out` stays empty until `__exit__` copies the buffer contents into it.
        self.out = ""
        self.io = StringIO()
        self.sh = logging.StreamHandler(self.io)
        self.logger = logger

    def __enter__(self):
        # Attach the in-memory handler for the duration of the block.
        self.logger.addHandler(self.sh)
        return self

    def __exit__(self, *exc):
        # Detach first, then snapshot everything that was written while attached.
        self.logger.removeHandler(self.sh)
        self.out = self.io.getvalue()

    def __repr__(self):
        return f"captured: {self.out}\n"
def _is_torch_fp16_available(device):
    """
    Probe whether a small float16 multiplication can be executed on `device`.

    Args:
        device: Anything accepted by `torch.device(...)`.

    Returns:
        `True` if a 2x2 float16 tensor can be moved to `device` and multiplied there; `False` if torch is not
        available or the probe fails on a non-CUDA device.

    Raises:
        ValueError: If the probe fails on a device of type `"cuda"`.
    """
    if not is_torch_available():
        return False

    import torch

    device = torch.device(device)

    try:
        x = torch.zeros((2, 2), dtype=torch.float16).to(device)
        _ = torch.mul(x, x)
        return True

    except Exception as e:
        if device.type == "cuda":
            # Chain the original failure explicitly so the root cause stays attached to the ValueError.
            raise ValueError(
                f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}"
            ) from e

        return False
"mps": lambda: 0, "default": 0} BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed} # This dispatches a defined function according to the accelerator from the function definitions. def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs): if device not in dispatch_table: return dispatch_table["default"](*args, **kwargs) fn = dispatch_table[device] # Some device agnostic functions return values. Need to guard against 'None' instead at # user level if fn is None: return None return fn(*args, **kwargs) # These are callables which automatically dispatch the function specific to the accelerator def backend_manual_seed(device: str, seed: int): return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed) def backend_empty_cache(device: str): return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE) def backend_device_count(device: str): return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT) # These are callables which return boolean behaviour flags and can be used to specify some # device agnostic alternative where the feature is unsupported. def backend_supports_training(device: str): if not is_torch_available(): return False if device not in BACKEND_SUPPORTS_TRAINING: device = "default" return BACKEND_SUPPORTS_TRAINING[device] # Guard for when Torch is not available if is_torch_available(): # Update device function dict mapping def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str): try: # Try to import the function directly spec_fn = getattr(device_spec_module, attribute_name) device_fn_dict[torch_device] = spec_fn except AttributeError as e: # If the function doesn't exist, and there is no default, throw an error if "default" not in device_fn_dict: raise AttributeError( f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found." 
) from e if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ: device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"] if not Path(device_spec_path).is_file(): raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}") try: import_name = device_spec_path[: device_spec_path.index(".py")] except ValueError as e: raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e device_spec_module = importlib.import_module(import_name) try: device_name = device_spec_module.DEVICE_NAME except AttributeError: raise AttributeError("Device spec file did not contain `DEVICE_NAME`") if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name: msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n" msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name." raise ValueError(msg) torch_device = device_name # Add one entry here for each `BACKEND_*` dictionary. update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN") update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN") update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN") update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING") ================================================ FILE: src/diffusers/utils/torch_utils.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
    is always created on the CPU.

    Args:
        shape (`tuple` or `list`): Shape of the tensor to create; `shape[0]` is treated as the batch size.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            Generator(s) used for sampling. A list with more than one entry yields one sample per batch element.
        device (`torch.device` or `str`, *optional*): Device of the returned tensor. Defaults to CPU.
        dtype (`torch.dtype`, *optional*): Data type passed to `torch.randn`.
        layout (`torch.layout`, *optional*): Layout passed to `torch.randn`. Defaults to `torch.strided`.

    Returns:
        `torch.Tensor`: The sampled tensor, located on `device`.

    Raises:
        ValueError: If a CUDA generator is passed while a tensor on a different device type is requested.
    """
    # device on which tensor is created defaults to device
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")
    # Accept plain strings such as "cuda": the checks below rely on `device.type`.
    if isinstance(device, str):
        device = torch.device(device)

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare the device *type*: a `torch.device` object is never equal to the string "mps".
            if device.type != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # `shape` may be a list, so convert the tail before concatenating it with a tuple.
        shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents
def apply_freeu(
    resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs
) -> Tuple["torch.Tensor", "torch.Tensor"]:
    """Applies the FreeU mechanism as introduced in https://arxiv.org/abs/2309.11497. Adapted from the official
    code repository: https://github.com/ChenyangSi/FreeU.

    For `resolution_idx` 0 (stage 1) and 1 (stage 2), the first half of the channels of `hidden_states` is scaled
    in place by `b1` / `b2`, and `res_hidden_states` is passed through `fourier_filter` with `threshold=1` and
    scale `s1` / `s2`. Any other `resolution_idx` returns both inputs untouched.

    Args:
        resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied.
        hidden_states (`torch.Tensor`): Inputs to the underlying block.
        res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block.
        s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features.
        s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features.
        b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
        b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
    """
    # Pick the pair of keyword names for this stage; other resolutions are passed through unchanged.
    if resolution_idx == 0:
        backbone_key, skip_key = "b1", "s1"
    elif resolution_idx == 1:
        backbone_key, skip_key = "b2", "s2"
    else:
        return hidden_states, res_hidden_states

    half = hidden_states.shape[1] // 2
    amplified = hidden_states[:, :half] * freeu_kwargs[backbone_key]
    # In-place write: the caller's `hidden_states` tensor is modified.
    hidden_states[:, :half] = amplified
    res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs[skip_key])

    return hidden_states, res_hidden_states
) if not ops[op](version.parse(got_ver), version.parse(want_ver)): raise ImportError( f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}" ) def require_version(requirement: str, hint: Optional[str] = None) -> None: """ Perform a runtime check of the dependency versions, using the exact same syntax used by pip. The installed module version comes from the *site-packages* dir via *importlib.metadata*. Args: requirement (`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy" hint (`str`, *optional*): what suggestion to print in case of requirements not being met Example: ```python require_version("pandas>1.1.2") require_version("numpy>1.18.5", "this is important to have for whatever reason") ```""" hint = f"\n{hint}" if hint is not None else "" # non-versioned check if re.match(r"^[\w_\-\d]+$", requirement): pkg, op, want_ver = requirement, None, None else: match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement) if not match: raise ValueError( "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but" f" got {requirement}" ) pkg, want_full = match[0] want_range = want_full.split(",") # there could be multiple requirements wanted = {} for w in want_range: match = re.findall(r"^([\s!=<>]{1,2})(.+)", w) if not match: raise ValueError( "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23," f" but got {requirement}" ) op, want_ver = match[0] wanted[op] = want_ver if op not in ops: raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}") # special case if pkg == "python": got_ver = ".".join([str(x) for x in sys.version_info[:3]]) for op, want_ver in wanted.items(): _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) return # check if any version is installed try: got_ver = importlib.metadata.version(pkg) except importlib.metadata.PackageNotFoundError: raise 
importlib.metadata.PackageNotFoundError( f"The '{requirement}' distribution was not found and is required by this application. {hint}" ) # check that the right version is installed if version number or a range was provided if want_ver is not None: for op, want_ver in wanted.items(): _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) def require_version_core(requirement): """require_version wrapper which emits a core-specific hint on failure""" hint = "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git main" return require_version(requirement, hint) ================================================ FILE: src/diffusers/video_processor.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import warnings from typing import List, Optional, Union import numpy as np import PIL import torch from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist class VideoProcessor(VaeImageProcessor): r"""Simple video processor.""" def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor: r""" Preprocesses input video(s). Args: video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `torch.Tensor`, `np.array`, `List[torch.Tensor]`, `List[np.array]`): The input video. It can be one of the following: * List of the PIL images. * List of list of PIL images. 
* 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`). * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). * List of 4D Torch tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`). * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`). * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width, num_channels)`. * 5D Torch tensors: expected shape for each array `(batch_size, num_frames, num_channels, height, width)`. height (`int`, *optional*, defaults to `None`): The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to get default height. width (`int`, *optional*`, defaults to `None`): The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get the default width. """ if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: warnings.warn( "Passing `video` as a list of 5d np.ndarray is deprecated." "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray", FutureWarning, ) video = np.concatenate(video, axis=0) if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5: warnings.warn( "Passing `video` as a list of 5d torch.Tensor is deprecated." "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor", FutureWarning, ) video = torch.cat(video, axis=0) # ensure the input is a list of videos: # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray) # - if it is is a single video, it is convereted to a list of one video. 
if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5: video = list(video) elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): video = [video] elif isinstance(video, list) and is_valid_image_imagelist(video[0]): video = video else: raise ValueError( "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" ) video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) # move the number of channels before the number of frames. video = video.permute(0, 2, 1, 3, 4) return video def postprocess_video( self, video: torch.Tensor, output_type: str = "np" ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: r""" Converts a video tensor to a list of frames for export. Args: video (`torch.Tensor`): The video as a tensor. output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. """ batch_size = video.shape[0] outputs = [] for batch_idx in range(batch_size): batch_vid = video[batch_idx].permute(1, 0, 2, 3) batch_output = self.postprocess(batch_vid, output_type) outputs.append(batch_output) if output_type == "np": outputs = np.stack(outputs) elif output_type == "pt": outputs = torch.stack(outputs) elif not output_type == "pil": raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") return outputs